10 Commits

Author SHA1 Message Date
531cd3fe08 style: format Rust workspace 2026-06-18 18:11:58 +08:00
013465fc06 docs: Phase 21 — decode CUDA graph + GPU argmax results
dash5, gpt-oss-20b FP8, warm-server vs llama.cpp MXFP4 (6 reps):
TP=2 TPOT 5.76-5.89 vs 7.42-8.45 ms (xserv 1.26-1.47x), TTFT 2.4x
ahead short/medium; TP=1 5.78-5.95 vs 2.80-3.22 ms (gap 2.5x -> 2.0x,
TTFT now ahead short/medium). GSM8K-50 through the graph path: 94%.
Lesson recorded: graphs bought ~0.6 ms (launches were already hidden
by async execution), the GPU argmax ~1 ms — measure, don't guess.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 20:12:37 +08:00
8414f8d1e6 sampling: GPU argmax fast path for greedy decode
sample() at temperature 0 copied the full [seq, 201088] BF16 logits
to the host and scanned them every token (~1 ms/token). Use the
Phase 15 argmax kernel (block reduction + 4-byte D2H) when logits are
BF16 on GPU; bench-gpt-oss's greedy sampler likewise. Exact-tie
logits may break differently than the host scan — greedy trajectories
can legitimately diverge at a tie token (GSM8K unchanged).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 20:12:37 +08:00
34224c7c93 gpt-oss: replay the whole batch=1 decode step as one CUDA graph
Split forward_decode_paged into host bookkeeping (decode_prepare +
ids/pos upload + advance_seq_len) and a pure-GPU decode_core. The
paged-KV and sparse-MoE designs already read every per-step variable
(block table, context lens, expert ids) from stable-address device
buffers, so decode_core captures as-is.

GptOssDecodeGraph captures lazily on the second decode step (the
first eager step warms cuBLAS) after a "retained warmup": the step
runs once with the allocator quarantine on, stocking the pool with a
dedicated block for every allocation so the capture itself never
pool-misses (a cudaMalloc while capturing is illegal — and the
capture's own quarantine is what would otherwise starve the pool).
NCCL all-reduces capture cleanly; TP=2 replays in lockstep.

Wired into tp_engine, bench-gpt-oss, and xserv-chat via
GraphedGptOssDecoder (batch>1 falls back to eager;
XSERV_DECODE_GRAPH=0 disables). Greedy tokens identical to eager.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 20:12:37 +08:00
4088f49b7d cuda: infrastructure for whole-step CUDA graph capture
- Thread-local launch stream (xserv_cuda::stream): every kernel
  wrapper, cublasSetStream, and NCCL collective now launches on
  current_stream_raw() — the legacy null stream by default (behavior
  unchanged), or the capture stream installed via push_stream during
  graph capture. Capture is impossible on the legacy stream.
- Allocator retain mode: blocks freed inside a retain window are
  quarantined (RetainedBlocks) instead of pooled, so an instantiated
  graph keeps exclusive ownership of every intermediate buffer it
  references across replays.
- Capture mode GLOBAL -> THREAD_LOCAL: concurrent TP rank threads
  must not poison each other's captures with their own cudaMallocs.
- embedding_device_ids / rope_inplace_device_pos: variants reading
  token ids / positions from persistent device buffers, replacing the
  per-call host upload that a captured region cannot contain.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 20:12:37 +08:00
2a92f268a9 docs: fill the Phase 19 gap, refresh README/roadmap to actual state
- docs/19-gpt-oss-moe.md: the numbered series jumped 18->20; write up
  gpt-oss arch deltas, harmony pitfalls, and the two CUDA debugging
  postmortems (fully-masked-tile NaN in flash-attention sinks;
  pre-__syncthreads early return reading uninitialized smem in the
  decode GEMV) — the highest-value learning content of that phase.
- README: models/perf/capabilities were frozen at the Qwen3-only era;
  now lists gpt-oss MoE, TP/PP, FP8/MXFP4, sparse MoE, and the
  llama.cpp standing.
- Roadmap: record where reality diverged from the plan at Phase 18+,
  add milestone entries and the ranked next-phase candidates
  (21 CUDA-graph MoE decode, 22 non-expert quant, 23 sparse prefill).
- sparse-moe benchmark doc: post-review-fix numbers.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 17:02:59 +08:00
5343391dbd review cleanups: pp+gpt-oss guard, sparse GEMV asserts, warnings
- --pp with gpt-oss now fails with a clear message instead of a
  cryptic missing-weight panic inside the Qwen3-only PP engine.
- Sparse GEMV wrappers assert K%16==0 (FP8) / K%32==0 (MXFP4) — the
  uint4-vectorized kernels would silently drop a tail otherwise.
- Document the topk_ids buffer holding i32 under an F32 dtype label
  (DType has no I32).
- Drop unused imports/locals and the cuBLASLt scale-mode constants
  orphaned by the strided-batched FP8 rework (e631a71).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 17:02:59 +08:00
1897b2e17a gpt-oss: drop debug syncs from forward; GPU broadcast bias-add
Decode carried three leftover cudaDeviceSynchronize (prefill one) from
NaN debugging — the Qwen3 path has none and the logits D2H in sample()
already orders against the null stream.

add_bias for rows>1 round-tripped the bias through the CPU (D2H + host
tile loop + H2D) on every call — 96 times per prefill across q/k/v/o.
Replace with a bias_add_2d broadcast kernel.

dash5, FP8 TP=2, warm-server: TTFT 35/49/94 -> 29/42/79 ms
(short/medium/long), TPOT 7.19-7.32 -> 6.99-7.21 ms. Greedy tokens
unchanged; GSM8K-50 94%.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 17:02:59 +08:00
63f5599717 server: serve gpt-oss on a single GPU via the TP engine (world=1)
gpt-oss has no single-GPU engine path, so --tp 1 fell through to the
Qwen3-only engine and every request 503'd. Route gpt_oss to run_tp
even at tp=1: NCCL world-1 init works and all_reduce already no-ops
(bench-gpt-oss --tp 1 exercised this path). Quantized gpt-oss (22 GB
FP8 / 13 GB MXFP4) now serves on one 32 GB 5090.

Also fix eval_gsm8k_fast.py --gpu to accept a device list ("2,3"):
it was type=int, so any --tp 2 run pinned CUDA_VISIBLE_DEVICES to one
GPU and rank 1's set_device panicked while rank 0 spun in NCCL init.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 16:29:10 +08:00
fb20178992 moe: sparse top-k decode — compute only routed experts (1.8x, beats llama TP=2)
Dense MoE replicated x across all 16 local experts and ran the full
batched GEMM, reading every expert's weights per token; the weighted
sum then discarded 12 of 16 results. Decode is memory-bound, so this
was ~8x wasted expert bytes — the entire decode gap vs llama.cpp.

New fused expert-indexed GEMVs (csrc/moe/moe_sparse.cu) read
topk_ids on-device (no host sync) and early-return block-uniformly
for experts other ranks own. FP8 runs W8A16 (activations stay BF16 —
tensor cores are irrelevant at M=1, and activation quantization error
disappears); MXFP4 runs W4A16. Per-expert bias + scale fused into the
GEMV epilogue; slot-indexed weighted sum skips (never multiplies)
unwritten non-local slots. Dense path retained for num_tokens > 8
(prefill) and via XSERV_DENSE_MOE=1 for A/B.

dash5 (RTX 5090), gpt-oss-20b FP8, TP=2: decode TPOT 13.9 -> 7.6 ms.
Warm-server vs llama.cpp MXFP4 TP=2: TPOT 7.19-7.32 vs 7.54-8.42 ms —
first config where xserv wins decode outright. GSM8K-100: 96% (dense
FP8: 91%). llama TP=1 (2.9 ms) remains ahead: next levers are decode
CUDA graphs, non-expert quantization, sparse prefill (docs/20).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 16:29:10 +08:00
69 changed files with 5499 additions and 1268 deletions

View File

@@ -3,18 +3,24 @@
> 从零用 **Rust + CUDA** 构建的 LLM 推理引擎,目标是吃透 LLM Serving 全栈技术。
xserv 不依赖 PyTorch / vLLM / TensorRT 等现成框架自己实现了张量抽象、CUDA kernel、
分词器、模型前向、KV cache、调度器和 OpenAI 兼容的 HTTP 服务。当前在单张 RTX 5090 上可以
跑通 **Qwen3-8B**BF16并提供一套与 **llama.cpp** 对比正确性和性能的标准 benchmark。
分词器、模型前向、KV cache、调度器和 OpenAI 兼容的 HTTP 服务。支持 **Qwen3-8B**BF16
**gpt-oss-20b**MoEBF16/FP8/MXFP4 量化),多卡 TP/PP并提供一套与 **llama.cpp**
对比正确性和性能的标准 benchmark。
## 现状一览
- **模型**GPT-2124M、Qwen3-8BBF16
- **性能**RTX 5090Qwen3-8B BF16贪心解码单流**56 tok/s**,约为 HF transformers 的 1.4×、llama.cpp 的 ~0.6×
- **精度**:在 AIME 2025 / GSM8K 上与 llama.cpp 同权重对比基本持平(数值保真度验证通过
- **服务**OpenAI 兼容 `/v1/chat/completions`,支持 SSE 流式输出
- **关键能力**:自写 GEMM / Flash-Attention 2(SM120) / Paged-Attention kernel、
分页 KV cache**CPU 换出/换入** 弹性显存、连续批处理continuous batching
CUDA Graph 解码、按显存自适应的 KV 池
- **模型**GPT-2124M、Qwen3-8BBF16、gpt-oss-20b32 专家 top-4 MoEharmony 格式)
- **性能**RTX 5090贪心,单流):
- Qwen3-8B BF16 单卡:约 56 tok/sHF transformers 的 1.4×
- gpt-oss-20b FP8 稀疏 MoE + CUDA Graph decode**TPOT 5.8ms~172 tok/s
TP=1/2 同速)**;同配置 TP=2 全面快于 llama.cpp1.26-1.47×llama
单卡模式2.8ms)仍领先,差距 2.0×
- **精度**GSM8K 全量与 llama.cpp 同权重持平94.5% vs 94.4%FP8/MXFP4 量化无回归
- **服务**OpenAI 兼容 `/v1/chat/completions`SSE 流式gpt-oss 量化后可**单卡 32GB 服务**
- **关键能力**:自写 GEMM / Flash-Attention 2(SM120含 attention sinks + sliding window) /
Paged-Attention kernel、分页 KV cache**CPU 换出/换入**)、连续批处理、
CUDA Graph 解码Qwen3 单卡 + gpt-oss 全路径整图回放)、**Tensor/Pipeline 并行**NCCLTP=1/2/4、PP=2/4
**FP8 W8A8 / MXFP4 W4A16 量化**、**稀疏 top-k MoE decode**(只算被路由的专家)
> 这是一个以学习为主的项目,逐 Phase 推进,每步都做数值/端到端验证。
@@ -26,16 +32,19 @@ xserv/
│ ├── gemm/ # GEMM (naive / tiled / gemv)
│ ├── attention/ # Flash-Attention 2 (SM120)、Paged-Attention、causal mask
│ ├── normalization/ # LayerNorm / RMSNorm
│ ├── activation/ # GELU / SiLU
│ ├── activation/ # GELU / SiLU / gpt-oss GLU
│ ├── embedding/ # embedding lookup / RoPE / transpose
│ ├── moe/ # MoE top-k 路由、稀疏专家 GEMV、加权求和
│ ├── quantization/ # FP8 量化/反量化、cuBLASLt FP8 GEMM、MXFP4 GEMV
│ └── reduce/ # softmax
├── crates/
│ ├── xserv-cuda/ # CUDA FFI、Stream、显存分配器、Pinned 内存、CUDA Graph
│ ├── xserv-tensor/ # Tensor 类型strided 布局、BF16/F16/F32、CPU↔GPU
│ ├── xserv-kernels/ # kernel registry自写 kernel + cuBLAS 可切换)
│ ├── xserv-tokenizer/ # BPE 分词器
│ ├── xserv-model/ # 模型定义GPT-2 / Qwen3、权重加载、KV cache、采样
── xserv-server/ # tokio + axum HTTP 服务、调度器
│ ├── xserv-distributed/ # NCCL FFI、TP 上下文AllReduce
── xserv-model/ # 模型定义GPT-2 / Qwen3 / gpt-oss MoE、权重加载、KV cache、采样
│ └── xserv-server/ # tokio + axum HTTP 服务、调度器、TP/PP 引擎
├── tools/ # 辅助脚本 + benchmark 套件(见下)
└── docs/ # 每个 Phase 的设计文档 + benchmark 报告
```
@@ -185,12 +194,14 @@ GSM8K 12 个格子全是 29/30xserv 与 llama.cpp 完全一致AIME 的 ±1
## 路线图(节选)
已完成 Phase 018CUDA 基础设施 → Tensor → GEMM → Transformer kernels → Attention →
已完成 Phase 021CUDA 基础设施 → Tensor → GEMM → Transformer kernels → Attention →
模型加载 → 分词器 → GPT-2 → KV cache → Qwen3-8B → Paged Attention → 连续批处理 →
HTTP API → Flash Attention 2 → 性能优化 → **张量并行TP****流水线并行PP**
HTTP API → Flash Attention 2 → 性能优化 → **张量并行TP****流水线并行PP**
**gpt-oss MoE + FP8/MXFP4 量化****稀疏 top-k MoE decode****decode CUDA Graph 整图回放**
并加入了 **llama.cpp 对比基准****KV CPU 换出** 等基础设施。
后续方向:PP microbatch/1F1B 流水线重叠吞吐收益、2D TP×PP、投机解码、量化FP8 / INT8、多模态。
后续方向:非专家权重量化lm_head/qkv/o、稀疏 prefillgrouped GEMM、server 侧 harmony
channel 分离、PP microbatch/1F1B、投机解码、多模态。详见 `docs/00-roadmap.md` 的实际进展记录。
## 许可

View File

@@ -111,6 +111,22 @@ pub fn cached_trim() {
/// Called from `GpuBuffer::Drop` for pooled buffers. Takes raw pointer
/// and size to avoid re-triggering Drop.
pub fn return_to_pool(ptr: *mut u8, len: usize) {
// During CUDA graph capture, buffers freed by the captured code are
// quarantined instead of pooled: the instantiated graph references their
// addresses on every replay, so they must never be handed to another
// consumer for as long as the graph lives.
let quarantined = RETAINED.with(|cell| {
let mut r = cell.borrow_mut();
if let Some(list) = r.as_mut() {
list.push((ptr, len));
true
} else {
false
}
});
if quarantined {
return;
}
ALLOCATOR.with(|cell| {
let mut alloc = cell.borrow_mut();
let bucket = bucket_size(len);
@@ -119,6 +135,44 @@ pub fn return_to_pool(ptr: *mut u8, len: usize) {
});
}
thread_local! {
static RETAINED: RefCell<Option<Vec<(*mut u8, usize)>>> = const { RefCell::new(None) };
}
/// Buffers freed while a retain window was active. Holding this keeps their
/// memory out of the pool; dropping it returns the blocks (on the owning
/// thread) for reuse.
pub struct RetainedBlocks(Vec<(*mut u8, usize)>);
impl Drop for RetainedBlocks {
fn drop(&mut self) {
for (ptr, len) in self.0.drain(..) {
return_to_pool(ptr, len);
}
}
}
/// Start quarantining buffers freed on this thread (see `return_to_pool`).
/// Must be paired with `end_retain` on the same thread; nesting unsupported.
pub fn begin_retain() {
RETAINED.with(|cell| {
let mut r = cell.borrow_mut();
assert!(r.is_none(), "begin_retain: retain window already active");
*r = Some(Vec::new());
});
}
/// Stop quarantining and hand the quarantined blocks to the caller.
pub fn end_retain() -> RetainedBlocks {
RETAINED.with(|cell| {
let list = cell
.borrow_mut()
.take()
.expect("end_retain without begin_retain");
RetainedBlocks(list)
})
}
/// Round up to next power-of-2, minimum 512 bytes.
fn bucket_size(size: usize) -> usize {
let min = 512;

View File

@@ -48,9 +48,7 @@ pub fn device_info(device: u32) -> Result<DeviceInfo> {
// Heap-allocate oversized buffer for cudaDeviceProp (layout varies by CUDA version).
// CUDA 12.x struct is ~5-6 KB; use 32 KB to guard against future growth.
let mut prop_buf = vec![0u8; 32768];
error::check(unsafe {
ffi::cudaGetDeviceProperties(prop_buf.as_mut_ptr(), device as i32)
})?;
error::check(unsafe { ffi::cudaGetDeviceProperties(prop_buf.as_mut_ptr(), device as i32) })?;
// Name is always the first field: char[256].
let name = unsafe { CStr::from_ptr(prop_buf.as_ptr() as *const c_char) }
.to_string_lossy()

View File

@@ -15,6 +15,7 @@ pub const CUDA_ERROR_OUT_OF_MEMORY: i32 = 2;
/// cudaStreamCaptureMode::cudaStreamCaptureModeGlobal
pub const CUDA_STREAM_CAPTURE_MODE_GLOBAL: i32 = 0;
pub const CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL: i32 = 1;
unsafe extern "C" {
// --- Device ---
@@ -63,11 +64,5 @@ unsafe extern "C" {
pub fn cudaGraphExecDestroy(graph_exec: CudaGraphExec) -> i32;
// --- Our test kernel ---
pub fn launch_vecadd_f32(
a: *const f32,
b: *const f32,
c: *mut f32,
n: i32,
stream: CudaStream,
);
pub fn launch_vecadd_f32(a: *const f32, b: *const f32, c: *mut f32, n: i32, stream: CudaStream);
}

View File

@@ -50,31 +50,25 @@ impl CudaGraph {
pub fn begin_capture(&mut self, stream: &CudaStream) -> Result<()> {
// If we have an old graph, destroy it first
self.destroy_inner();
// THREAD_LOCAL: only "potentially unsafe" CUDA calls (cudaMalloc etc.)
// made by THIS thread invalidate the capture. With GLOBAL mode, TP rank
// threads capturing concurrently would poison each other's captures.
error::check(unsafe {
ffi::cudaStreamBeginCapture(
stream.as_raw(),
ffi::CUDA_STREAM_CAPTURE_MODE_GLOBAL,
)
ffi::cudaStreamBeginCapture(stream.as_raw(), ffi::CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL)
})
}
/// End capture and instantiate the executable graph.
pub fn end_capture(&mut self, stream: &CudaStream) -> Result<()> {
error::check(unsafe {
ffi::cudaStreamEndCapture(stream.as_raw(), &mut self.graph)
})?;
error::check(unsafe {
ffi::cudaGraphInstantiate(&mut self.exec, self.graph, 0)
})
error::check(unsafe { ffi::cudaStreamEndCapture(stream.as_raw(), &mut self.graph) })?;
error::check(unsafe { ffi::cudaGraphInstantiate(&mut self.exec, self.graph, 0) })
}
/// Replay the captured graph on `stream`.
/// Panics if no graph has been captured yet.
pub fn launch(&self, stream: &CudaStream) -> Result<()> {
assert!(self.is_ready(), "CudaGraph::launch called before capture");
error::check(unsafe {
ffi::cudaGraphLaunch(self.exec, stream.as_raw())
})
error::check(unsafe { ffi::cudaGraphLaunch(self.exec, stream.as_raw()) })
}
fn destroy_inner(&mut self) {

View File

@@ -11,4 +11,4 @@ pub use device::DeviceInfo;
pub use error::{CudaError, Result};
pub use graph::CudaGraph;
pub use memory::{GpuBuffer, PinnedBuffer};
pub use stream::CudaStream;
pub use stream::{CudaStream, StreamGuard, current_stream_raw, push_stream};

View File

@@ -22,7 +22,12 @@ impl GpuBuffer {
assert!(len > 0, "cannot allocate 0 bytes on GPU");
let mut ptr = std::ptr::null_mut();
error::check(unsafe { ffi::cudaMalloc(&mut ptr, len) })?;
Ok(Self { ptr, len, owned: true, pooled: false })
Ok(Self {
ptr,
len,
owned: true,
pooled: false,
})
}
/// Mark this buffer as pooled (returned to caching allocator on drop)
@@ -92,9 +97,7 @@ impl GpuBuffer {
/// Copy from another GPU buffer (D2D).
pub fn copy_from_device(&mut self, src: &GpuBuffer) -> Result<()> {
let n = src.len.min(self.len);
error::check(unsafe {
ffi::cudaMemcpy(self.ptr, src.ptr, n, ffi::CUDA_MEMCPY_D2D)
})
error::check(unsafe { ffi::cudaMemcpy(self.ptr, src.ptr, n, ffi::CUDA_MEMCPY_D2D) })
}
/// Fill buffer with zeros.
@@ -103,7 +106,13 @@ impl GpuBuffer {
}
/// Copy `count` bytes from `src` buffer at `src_offset` to this buffer at `dst_offset`.
pub fn copy_from_device_at(&mut self, src: &GpuBuffer, src_offset: usize, dst_offset: usize, count: usize) -> Result<()> {
pub fn copy_from_device_at(
&mut self,
src: &GpuBuffer,
src_offset: usize,
dst_offset: usize,
count: usize,
) -> Result<()> {
assert!(src_offset + count <= src.len);
assert!(dst_offset + count <= self.len);
error::check(unsafe {
@@ -117,7 +126,14 @@ impl GpuBuffer {
}
/// Async copy `count` bytes from `src` at `src_offset` to `self` at `dst_offset` on `stream`.
pub fn copy_from_device_at_async(&mut self, src: &GpuBuffer, src_offset: usize, dst_offset: usize, count: usize, stream: &CudaStream) -> Result<()> {
pub fn copy_from_device_at_async(
&mut self,
src: &GpuBuffer,
src_offset: usize,
dst_offset: usize,
count: usize,
stream: &CudaStream,
) -> Result<()> {
assert!(src_offset + count <= src.len);
assert!(dst_offset + count <= self.len);
error::check(unsafe {
@@ -161,9 +177,7 @@ impl GpuBuffer {
/// Async zero fill on stream.
pub fn zero_async(&mut self, stream: &CudaStream) -> Result<()> {
error::check(unsafe {
ffi::cudaMemsetAsync(self.ptr, 0, self.len, stream.as_raw())
})
error::check(unsafe { ffi::cudaMemsetAsync(self.ptr, 0, self.len, stream.as_raw()) })
}
/// Consume the buffer without freeing GPU memory. Returns the raw pointer and length.
@@ -178,7 +192,12 @@ impl GpuBuffer {
/// Reconstruct a GpuBuffer from a raw pointer + length.
/// Safety: ptr must have been allocated with cudaMalloc, len must be correct.
pub unsafe fn from_raw(ptr: *mut u8, len: usize) -> Self {
Self { ptr, len, owned: true, pooled: false }
Self {
ptr,
len,
owned: true,
pooled: false,
}
}
/// Create a non-owning view of GPU memory. Dropping this buffer does NOT
@@ -189,7 +208,12 @@ impl GpuBuffer {
/// `ptr` must point to a valid GPU allocation of at least `len` bytes that
/// will remain live for the lifetime of the returned `GpuBuffer`.
pub unsafe fn borrow_raw(ptr: *mut u8, len: usize) -> Self {
Self { ptr, len, owned: false, pooled: false }
Self {
ptr,
len,
owned: false,
pooled: false,
}
}
}

View File

@@ -31,3 +31,39 @@ impl Drop for CudaStream {
// Can move across threads, but not shared without synchronization
unsafe impl Send for CudaStream {}
// --- Thread-local launch stream -------------------------------------------
//
// Every kernel wrapper in xserv-kernels launches on `current_stream_raw()`,
// which defaults to the legacy null stream (the historical behavior). CUDA
// graph capture requires work to be issued on an explicit stream, so capture
// code installs its stream here for the duration of the captured region via
// `push_stream` / `StreamGuard`.
use std::cell::Cell;
thread_local! {
static CURRENT_STREAM: Cell<ffi::CudaStream> = const { Cell::new(std::ptr::null_mut()) };
}
/// The stream kernel launches on this thread should use (null = legacy default).
pub fn current_stream_raw() -> ffi::CudaStream {
CURRENT_STREAM.with(|c| c.get())
}
/// RAII guard that installs a launch stream for the current thread and
/// restores the previous one on drop.
pub struct StreamGuard {
prev: ffi::CudaStream,
}
pub fn push_stream(stream: &CudaStream) -> StreamGuard {
let prev = CURRENT_STREAM.with(|c| c.replace(stream.as_raw()));
StreamGuard { prev }
}
impl Drop for StreamGuard {
fn drop(&mut self) {
CURRENT_STREAM.with(|c| c.set(self.prev));
}
}

View File

@@ -14,7 +14,10 @@ fn test_device_info() {
info.compute_major, info.compute_minor
);
println!(" SM Count: {}", info.sm_count);
println!(" Shared Mem/Block: {} KB", info.shared_mem_per_block / 1024);
println!(
" Shared Mem/Block: {} KB",
info.shared_mem_per_block / 1024
);
println!(" Warp Size: {}", info.warp_size);
println!(" Max Threads/Block: {}", info.max_threads_per_block);
@@ -145,7 +148,11 @@ fn test_caching_allocator() {
// Second allocation of same size: should hit cache
let _buf2 = alloc.alloc(1024).unwrap();
assert_eq!(alloc.stats().cuda_malloc_count, 1, "should reuse cached buffer");
assert_eq!(
alloc.stats().cuda_malloc_count,
1,
"should reuse cached buffer"
);
assert_eq!(alloc.stats().cache_hit_count, 1);
}
@@ -198,11 +205,17 @@ fn test_async_copy() {
}
let mut gpu = GpuBuffer::alloc(4096).unwrap();
unsafe { gpu.copy_from_host_async(pinned.as_slice(), &stream).unwrap() };
unsafe {
gpu.copy_from_host_async(pinned.as_slice(), &stream)
.unwrap()
};
stream.synchronize().unwrap();
let mut out_pinned = PinnedBuffer::alloc(4096).unwrap();
unsafe { gpu.copy_to_host_async(out_pinned.as_mut_slice(), &stream).unwrap() };
unsafe {
gpu.copy_to_host_async(out_pinned.as_mut_slice(), &stream)
.unwrap()
};
stream.synchronize().unwrap();
assert_eq!(pinned.as_slice(), out_pinned.as_slice());

View File

@@ -34,7 +34,12 @@ pub const NCCL_SUCCESS: i32 = 0;
unsafe extern "C" {
pub fn ncclGetUniqueId(uid: *mut NcclUniqueId) -> i32;
// ncclUniqueId is passed BY VALUE (a 128-byte struct) per the NCCL ABI.
pub fn ncclCommInitRank(comm: *mut NcclComm, nranks: i32, commid: NcclUniqueId, rank: i32) -> i32;
pub fn ncclCommInitRank(
comm: *mut NcclComm,
nranks: i32,
commid: NcclUniqueId,
rank: i32,
) -> i32;
pub fn ncclCommDestroy(comm: NcclComm) -> i32;
pub fn ncclAllReduce(
sendbuff: *const c_void,
@@ -78,5 +83,10 @@ pub fn err_string(result: i32) -> String {
}
pub fn check(result: i32, what: &str) {
assert_eq!(result, NCCL_SUCCESS, "{what} failed: {}", err_string(result));
assert_eq!(
result,
NCCL_SUCCESS,
"{what} failed: {}",
err_string(result)
);
}

View File

@@ -9,15 +9,18 @@ pub mod ffi;
use std::ffi::c_void;
use ffi::{NcclComm, NcclUniqueId};
use xserv_cuda::device;
use xserv_cuda::GpuBuffer;
use xserv_cuda::device;
pub use ffi::NcclUniqueId as UniqueId;
/// The CUDA "null" (default) stream. The model's kernels and cuBLAS calls run
/// on it, so issuing NCCL on the same stream keeps AllReduce correctly ordered
/// after the producing matmul and before the consuming kernel — no extra sync.
const NULL_STREAM: xserv_cuda::ffi::CudaStream = std::ptr::null_mut();
/// NCCL is issued on the thread's current launch stream (legacy null stream
/// by default, the capture stream during CUDA graph capture). The model's
/// kernels run on the same stream, so AllReduce stays correctly ordered after
/// the producing matmul and before the consuming kernel — no extra sync.
fn launch_stream() -> xserv_cuda::ffi::CudaStream {
xserv_cuda::stream::current_stream_raw()
}
/// Generate a unique id on one rank (typically rank 0) and broadcast the bytes
/// to all ranks out-of-band (e.g. via a shared variable across threads).
@@ -52,7 +55,12 @@ impl TpContext {
"ncclCommInitRank",
);
ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(init)");
Self { rank, world, device, comm }
Self {
rank,
world,
device,
comm,
}
}
/// In-place AllReduce(sum) over `count` BF16 elements in `buf`.
@@ -80,7 +88,7 @@ impl TpContext {
ffi::NCCL_BF16,
ffi::NCCL_SUM,
self.comm,
NULL_STREAM,
launch_stream(),
)
},
"ncclAllReduce",
@@ -124,7 +132,12 @@ impl PpContext {
"ncclCommInitRank",
);
ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(init)");
Self { stage, world, device, comm }
Self {
stage,
world,
device,
comm,
}
}
/// Send `count` BF16 elements at `ptr` to `peer`, on the null stream so it is
@@ -135,7 +148,16 @@ impl PpContext {
/// `ptr` must point to at least `count` BF16 elements of valid device memory.
pub fn send_bf16_ptr(&self, ptr: *const c_void, count: usize, peer: usize) {
ffi::check(
unsafe { ffi::ncclSend(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, NULL_STREAM) },
unsafe {
ffi::ncclSend(
ptr,
count,
ffi::NCCL_BF16,
peer as i32,
self.comm,
launch_stream(),
)
},
"ncclSend",
);
}
@@ -146,7 +168,16 @@ impl PpContext {
/// `ptr` must point to at least `count` BF16 elements of valid device memory.
pub fn recv_bf16_ptr(&self, ptr: *mut c_void, count: usize, peer: usize) {
ffi::check(
unsafe { ffi::ncclRecv(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, NULL_STREAM) },
unsafe {
ffi::ncclRecv(
ptr,
count,
ffi::NCCL_BF16,
peer as i32,
self.comm,
launch_stream(),
)
},
"ncclRecv",
);
}

View File

@@ -2,8 +2,8 @@
use half::bf16;
use std::thread;
use xserv_cuda::{device, GpuBuffer};
use xserv_distributed::{get_unique_id, TpContext};
use xserv_cuda::{GpuBuffer, device};
use xserv_distributed::{TpContext, get_unique_id};
#[test]
fn allreduce_two_gpu_sum() {
@@ -25,9 +25,7 @@ fn allreduce_two_gpu_sum() {
// Rank r fills its buffer with (r + 1).
let val = bf16::from_f32((rank + 1) as f32);
let host = vec![val; n];
let src = unsafe {
std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2)
};
let src = unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) };
let mut buf = GpuBuffer::alloc(n * 2).unwrap();
buf.copy_from_host(src).unwrap();

View File

@@ -6,8 +6,8 @@
use half::bf16;
use std::ffi::c_void;
use std::thread;
use xserv_cuda::{device, GpuBuffer};
use xserv_distributed::{get_unique_id, PpContext};
use xserv_cuda::{GpuBuffer, device};
use xserv_distributed::{PpContext, get_unique_id};
#[test]
fn pp_send_recv_two_stages() {
@@ -30,7 +30,8 @@ fn pp_send_recv_two_stages() {
if stage == 0 {
// Fill with a known pattern and send to stage 1.
let host: Vec<bf16> = (0..n).map(|i| bf16::from_f32((i % 97) as f32)).collect();
let src = unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) };
let src =
unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) };
buf.copy_from_host(src).unwrap();
pp.send_bf16_ptr(buf.as_mut_ptr() as *const c_void, n, 1);
device::synchronize().unwrap();

View File

@@ -31,6 +31,7 @@ fn main() {
.file("../../csrc/attention/paged_attention.cu")
.file("../../csrc/attention/reshape_and_cache.cu")
.file("../../csrc/moe/moe_kernels.cu")
.file("../../csrc/moe/moe_sparse.cu")
.file("../../csrc/quantization/dequant_fp8.cu")
.file("../../csrc/quantization/quantize_fp8.cu")
.file("../../csrc/quantization/mxfp4_gemm.cu")

View File

@@ -6,76 +6,220 @@ unsafe extern "C" {
fn launch_gelu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_silu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_silu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_scale_f32(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void);
fn launch_scale_bf16(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void);
fn launch_add_f32(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_mul_f32(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_mul_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_gpt_oss_glu_bf16(gate_up: *const c_void, out: *mut c_void, n_elements: i32,
alpha: f32, limit: f32, stream: *mut c_void);
fn launch_scale_f32(
x: *const c_void,
out: *mut c_void,
scale: f32,
n: i32,
stream: *mut c_void,
);
fn launch_scale_bf16(
x: *const c_void,
out: *mut c_void,
scale: f32,
n: i32,
stream: *mut c_void,
);
fn launch_add_f32(
a: *const c_void,
b: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
);
fn launch_add_bf16(
a: *const c_void,
b: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
);
fn launch_mul_f32(
a: *const c_void,
b: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
);
fn launch_mul_bf16(
a: *const c_void,
b: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
);
fn launch_silu_mul_bf16(
gate: *const c_void,
up: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
);
fn launch_gpt_oss_glu_bf16(
gate_up: *const c_void,
out: *mut c_void,
n_elements: i32,
alpha: f32,
limit: f32,
stream: *mut c_void,
);
fn launch_bias_add_2d_bf16(
x: *const c_void,
bias: *const c_void,
out: *mut c_void,
rows: i32,
cols: i32,
stream: *mut c_void,
);
}
fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
fn dispatch_unary(
x: &Tensor,
f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
) -> Tensor {
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
let n = x.numel();
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
assert!(
n <= i32::MAX as usize,
"tensor too large for i32 kernel param ({n} elements)"
);
let n = n as i32;
unsafe {
match x.dtype() {
DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
DType::BF16 => bf16_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
DType::F32 => f32_fn(
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
n,
xserv_cuda::current_stream_raw(),
),
DType::BF16 => bf16_fn(
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
n,
xserv_cuda::current_stream_raw(),
),
_ => panic!("unsupported dtype"),
}
}
out
}
fn dispatch_binary(a: &Tensor, b: &Tensor,
f32_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void),
bf16_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
fn dispatch_binary(
a: &Tensor,
b: &Tensor,
f32_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void),
bf16_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void),
) -> Tensor {
assert_eq!(a.shape(), b.shape());
assert!(a.is_contiguous() && b.is_contiguous());
assert!(matches!(a.device(), Device::Cuda(_)));
assert_eq!(a.dtype(), b.dtype());
let out = Tensor::empty(a.shape(), a.dtype(), a.device());
let n = a.numel();
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
assert!(
n <= i32::MAX as usize,
"tensor too large for i32 kernel param ({n} elements)"
);
let n = n as i32;
unsafe {
match a.dtype() {
DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
DType::BF16 => bf16_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
DType::F32 => f32_fn(
a.data_ptr() as _,
b.data_ptr() as _,
out.data_ptr() as *mut c_void,
n,
xserv_cuda::current_stream_raw(),
),
DType::BF16 => bf16_fn(
a.data_ptr() as _,
b.data_ptr() as _,
out.data_ptr() as *mut c_void,
n,
xserv_cuda::current_stream_raw(),
),
_ => panic!("unsupported dtype"),
}
}
out
}
pub fn gelu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_gelu_f32, launch_gelu_bf16) }
pub fn silu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_silu_f32, launch_silu_bf16) }
pub fn gelu(x: &Tensor) -> Tensor {
dispatch_unary(x, launch_gelu_f32, launch_gelu_bf16)
}
pub fn silu(x: &Tensor) -> Tensor {
dispatch_unary(x, launch_silu_f32, launch_silu_bf16)
}
pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
let n = x.numel();
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
assert!(
n <= i32::MAX as usize,
"tensor too large for i32 kernel param ({n} elements)"
);
let n = n as i32;
unsafe {
match x.dtype() {
DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
DType::F32 => launch_scale_f32(
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
scale_val,
n,
xserv_cuda::current_stream_raw(),
),
DType::BF16 => launch_scale_bf16(
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
scale_val,
n,
xserv_cuda::current_stream_raw(),
),
_ => panic!("unsupported dtype for scale"),
}
}
out
}
pub fn add(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_add_f32, launch_add_bf16) }
pub fn mul(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16) }
pub fn add(a: &Tensor, b: &Tensor) -> Tensor {
dispatch_binary(a, b, launch_add_f32, launch_add_bf16)
}
pub fn mul(a: &Tensor, b: &Tensor) -> Tensor {
dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16)
}
/// Row-broadcast bias add: out[r, c] = x[r, c] + bias[c] (BF16 only).
pub fn bias_add_2d(x: &Tensor, bias: &Tensor) -> Tensor {
assert_eq!(x.ndim(), 2);
assert_eq!(bias.ndim(), 1);
assert_eq!(x.dtype(), DType::BF16);
assert_eq!(bias.dtype(), DType::BF16);
assert!(x.is_contiguous() && bias.is_contiguous());
assert!(matches!(x.device(), Device::Cuda(_)));
let rows = x.shape()[0];
let cols = x.shape()[1];
assert_eq!(
bias.shape()[0],
cols,
"bias size {} != cols {cols}",
bias.shape()[0]
);
assert!(rows * cols <= i32::MAX as usize);
let out = Tensor::empty(&[rows, cols], DType::BF16, x.device());
unsafe {
launch_bias_add_2d_bf16(
x.data_ptr() as _,
bias.data_ptr() as _,
out.data_ptr() as *mut c_void,
rows as i32,
cols as i32,
xserv_cuda::current_stream_raw(),
);
}
out
}
/// Fused SiLU×Mul: out = silu(gate) * up (BF16 only)
/// Saves one HBM read + one HBM write compared to separate silu + mul.
@@ -86,7 +230,10 @@ pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {
assert_eq!(gate.dtype(), DType::BF16, "silu_mul requires BF16");
let out = Tensor::empty(gate.shape(), gate.dtype(), gate.device());
let n = gate.numel();
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
assert!(
n <= i32::MAX as usize,
"tensor too large for i32 kernel param ({n} elements)"
);
let n = n as i32;
unsafe {
launch_silu_mul_bf16(
@@ -94,7 +241,7 @@ pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {
up.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
n,
std::ptr::null_mut(),
xserv_cuda::current_stream_raw(),
);
}
out
@@ -122,7 +269,7 @@ pub fn gpt_oss_glu(gate_up: &Tensor, alpha: f32, limit: f32) -> Tensor {
n_elements,
alpha,
limit,
std::ptr::null_mut(),
xserv_cuda::current_stream_raw(),
);
}
out

View File

@@ -2,8 +2,13 @@ use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_argmax_bf16(logits: *const c_void, out_idx: *mut c_void,
rows: i32, cols: i32, stream: *mut c_void);
fn launch_argmax_bf16(
logits: *const c_void,
out_idx: *mut c_void,
rows: i32,
cols: i32,
stream: *mut c_void,
);
}
/// GPU argmax over the last dim of a [rows, cols] BF16 tensor.
@@ -19,7 +24,10 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec<u32> {
assert_eq!(logits.ndim(), 2, "argmax expects a 2D [rows, cols] tensor");
assert_eq!(logits.dtype(), DType::BF16, "argmax kernel is BF16-only");
assert!(logits.is_contiguous(), "argmax requires contiguous input");
assert!(matches!(logits.device(), Device::Cuda(_)), "argmax requires GPU input");
assert!(
matches!(logits.device(), Device::Cuda(_)),
"argmax requires GPU input"
);
let rows = logits.shape()[0];
let cols = logits.shape()[1];
@@ -35,8 +43,9 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec<u32> {
launch_argmax_bf16(
logits.data_ptr() as *const c_void,
out.as_mut_ptr() as *mut c_void,
rows as i32, cols as i32,
std::ptr::null_mut(),
rows as i32,
cols as i32,
xserv_cuda::current_stream_raw(),
);
}
@@ -44,9 +53,8 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec<u32> {
out.copy_to_host(&mut host_bytes).expect("argmax D2H");
drop(out); // returned to pool
let host_i32: &[i32] = unsafe {
std::slice::from_raw_parts(host_bytes.as_ptr() as *const i32, rows)
};
let host_i32: &[i32] =
unsafe { std::slice::from_raw_parts(host_bytes.as_ptr() as *const i32, rows) };
host_i32.iter().map(|&v| v as u32).collect()
}
@@ -62,4 +70,3 @@ pub fn argmax_bf16_single(logits: &Tensor) -> u32 {
};
argmax_bf16_to_host(&view)[0]
}

View File

@@ -6,28 +6,67 @@ use crate::gemm::batched_matmul;
use crate::softmax::softmax;
unsafe extern "C" {
fn launch_causal_mask_f32(scores: *mut c_void, batch: i32, rows: i32, cols: i32,
offset: i32, stream: *mut c_void);
fn launch_causal_mask_bf16(scores: *mut c_void, batch: i32, rows: i32, cols: i32,
offset: i32, stream: *mut c_void);
fn launch_causal_mask_f32(
scores: *mut c_void,
batch: i32,
rows: i32,
cols: i32,
offset: i32,
stream: *mut c_void,
);
fn launch_causal_mask_bf16(
scores: *mut c_void,
batch: i32,
rows: i32,
cols: i32,
offset: i32,
stream: *mut c_void,
);
fn launch_flash_attention_bf16(
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
q_len: i32, kv_len: i32, head_dim: i32,
scale: f32, causal: i32, stream: *mut c_void,
q: *const c_void,
k: *const c_void,
v: *const c_void,
o: *mut c_void,
batch: i32,
num_q_heads: i32,
num_kv_heads: i32,
q_len: i32,
kv_len: i32,
head_dim: i32,
scale: f32,
causal: i32,
stream: *mut c_void,
);
fn launch_flash_attention_sinks_bf16(
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
q: *const c_void,
k: *const c_void,
v: *const c_void,
o: *mut c_void,
sinks: *const c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
q_len: i32, kv_len: i32, head_dim: i32,
scale: f32, causal: i32, window_size: i32, stream: *mut c_void,
batch: i32,
num_q_heads: i32,
num_kv_heads: i32,
q_len: i32,
kv_len: i32,
head_dim: i32,
scale: f32,
causal: i32,
window_size: i32,
stream: *mut c_void,
);
fn launch_decode_attention_bf16(
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
kv_len: i32, head_dim: i32,
scale: f32, causal: i32, stream: *mut c_void,
q: *const c_void,
k: *const c_void,
v: *const c_void,
o: *mut c_void,
batch: i32,
num_q_heads: i32,
num_kv_heads: i32,
kv_len: i32,
head_dim: i32,
scale: f32,
causal: i32,
stream: *mut c_void,
);
fn launch_paged_decode_attention_bf16(
q: *const c_void,
@@ -36,9 +75,13 @@ unsafe extern "C" {
o: *mut c_void,
block_tables: *const i32,
context_lens: *const i32,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
head_dim: i32, max_blocks_per_seq: i32,
scale: f32, stream: *mut c_void,
batch: i32,
num_q_heads: i32,
num_kv_heads: i32,
head_dim: i32,
max_blocks_per_seq: i32,
scale: f32,
stream: *mut c_void,
);
fn launch_paged_decode_attention_sinks_bf16(
q: *const c_void,
@@ -48,24 +91,40 @@ unsafe extern "C" {
block_tables: *const i32,
context_lens: *const i32,
sinks: *const c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
head_dim: i32, max_blocks_per_seq: i32,
scale: f32, window_size: i32, stream: *mut c_void,
batch: i32,
num_q_heads: i32,
num_kv_heads: i32,
head_dim: i32,
max_blocks_per_seq: i32,
scale: f32,
window_size: i32,
stream: *mut c_void,
);
fn launch_reshape_and_cache_bf16(
k_src: *const c_void, v_src: *const c_void,
k_pool: *mut c_void, v_pool: *mut c_void,
k_src: *const c_void,
v_src: *const c_void,
k_pool: *mut c_void,
v_pool: *mut c_void,
block_ids: *const c_void,
num_tokens: i32, num_heads: i32,
head_dim: i32, start_pos: i32, block_size: i32,
num_tokens: i32,
num_heads: i32,
head_dim: i32,
start_pos: i32,
block_size: i32,
stream: *mut c_void,
);
fn launch_reshape_and_cache_batched_bf16(
k_src: *const c_void, v_src: *const c_void,
k_pool: *mut c_void, v_pool: *mut c_void,
block_tables: *const c_void, kv_lens: *const c_void,
batch: i32, num_heads: i32,
head_dim: i32, block_size: i32, max_blocks_per_seq: i32,
k_src: *const c_void,
v_src: *const c_void,
k_pool: *mut c_void,
v_pool: *mut c_void,
block_tables: *const c_void,
kv_lens: *const c_void,
batch: i32,
num_heads: i32,
head_dim: i32,
block_size: i32,
max_blocks_per_seq: i32,
stream: *mut c_void,
);
}
@@ -84,20 +143,30 @@ unsafe extern "C" {
/// `block_ids_gpu` must contain at least `(start_pos + num_tokens + block_size - 1) / block_size`
/// valid physical block ids.
pub unsafe fn reshape_and_cache_bf16(
k_src: *const c_void, v_src: *const c_void,
k_pool_ptr: *mut c_void, v_pool_ptr: *mut c_void,
k_src: *const c_void,
v_src: *const c_void,
k_pool_ptr: *mut c_void,
v_pool_ptr: *mut c_void,
block_ids_gpu: *const i32,
num_tokens: usize, num_heads: usize,
head_dim: usize, start_pos: usize, block_size: usize,
num_tokens: usize,
num_heads: usize,
head_dim: usize,
start_pos: usize,
block_size: usize,
stream: *mut c_void,
) {
unsafe {
launch_reshape_and_cache_bf16(
k_src, v_src,
k_pool_ptr, v_pool_ptr,
k_src,
v_src,
k_pool_ptr,
v_pool_ptr,
block_ids_gpu as *const c_void,
num_tokens as i32, num_heads as i32,
head_dim as i32, start_pos as i32, block_size as i32,
num_tokens as i32,
num_heads as i32,
head_dim as i32,
start_pos as i32,
block_size as i32,
stream,
);
}
@@ -113,21 +182,32 @@ pub unsafe fn reshape_and_cache_bf16(
/// All pointers must be on the same GPU. `block_tables` and `kv_lens` must
/// already be synced to the device for the active batch.
pub unsafe fn reshape_and_cache_batched_bf16(
k_src: *const c_void, v_src: *const c_void,
k_pool_ptr: *mut c_void, v_pool_ptr: *mut c_void,
block_tables_gpu: *const i32, kv_lens_gpu: *const i32,
batch: usize, num_heads: usize,
head_dim: usize, block_size: usize, max_blocks_per_seq: usize,
k_src: *const c_void,
v_src: *const c_void,
k_pool_ptr: *mut c_void,
v_pool_ptr: *mut c_void,
block_tables_gpu: *const i32,
kv_lens_gpu: *const i32,
batch: usize,
num_heads: usize,
head_dim: usize,
block_size: usize,
max_blocks_per_seq: usize,
stream: *mut c_void,
) {
unsafe {
launch_reshape_and_cache_batched_bf16(
k_src, v_src,
k_pool_ptr, v_pool_ptr,
k_src,
v_src,
k_pool_ptr,
v_pool_ptr,
block_tables_gpu as *const c_void,
kv_lens_gpu as *const c_void,
batch as i32, num_heads as i32,
head_dim as i32, block_size as i32, max_blocks_per_seq as i32,
batch as i32,
num_heads as i32,
head_dim as i32,
block_size as i32,
max_blocks_per_seq as i32,
stream,
);
}
@@ -143,13 +223,19 @@ fn apply_causal_mask(scores: &Tensor, offset: usize) {
match scores.dtype() {
DType::F32 => launch_causal_mask_f32(
scores.data_ptr() as *mut c_void,
batch as i32, rows as i32, cols as i32, offset as i32,
std::ptr::null_mut(),
batch as i32,
rows as i32,
cols as i32,
offset as i32,
xserv_cuda::current_stream_raw(),
),
DType::BF16 => launch_causal_mask_bf16(
scores.data_ptr() as *mut c_void,
batch as i32, rows as i32, cols as i32, offset as i32,
std::ptr::null_mut(),
batch as i32,
rows as i32,
cols as i32,
offset as i32,
xserv_cuda::current_stream_raw(),
),
_ => panic!("unsupported dtype for causal mask"),
}
@@ -214,11 +300,7 @@ pub fn decode_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Tensor {
let kv_len = k.shape()[2];
let scale = 1.0 / (head_dim as f32).sqrt();
let output = Tensor::empty(
&[batch, num_q_heads, 1, head_dim],
DType::BF16,
q.device(),
);
let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
unsafe {
launch_decode_attention_bf16(
@@ -233,7 +315,7 @@ pub fn decode_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Tensor {
head_dim as i32,
scale,
1, // causal (always 1 for decode)
std::ptr::null_mut(),
xserv_cuda::current_stream_raw(),
);
}
@@ -266,8 +348,14 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens
assert_eq!(k.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
assert_eq!(v.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
assert!(num_q_heads % num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads");
assert!(head_dim <= 128, "flash_attention supports head_dim up to 128");
assert!(
num_q_heads % num_kv_heads == 0,
"num_q_heads must be divisible by num_kv_heads"
);
assert!(
head_dim <= 128,
"flash_attention supports head_dim up to 128"
);
// Dispatch to specialized decode kernel for single-token generation
if q_len == 1 {
@@ -295,7 +383,7 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens
head_dim as i32,
scale,
if causal { 1 } else { 0 },
std::ptr::null_mut(),
xserv_cuda::current_stream_raw(),
);
}
@@ -333,10 +421,18 @@ pub fn flash_attention_sinks(
assert_eq!(v.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
assert!(num_q_heads % num_kv_heads == 0);
assert!(head_dim <= 128);
assert_eq!(sinks.shape()[0], num_q_heads, "sinks must have num_q_heads entries");
assert_eq!(
sinks.shape()[0],
num_q_heads,
"sinks must have num_q_heads entries"
);
let scale = 1.0 / (head_dim as f32).sqrt();
let output = Tensor::empty(&[batch, num_q_heads, q_len, head_dim], DType::BF16, q.device());
let output = Tensor::empty(
&[batch, num_q_heads, q_len, head_dim],
DType::BF16,
q.device(),
);
unsafe {
launch_flash_attention_sinks_bf16(
@@ -354,7 +450,7 @@ pub fn flash_attention_sinks(
scale,
1, // always causal
window_size as i32,
std::ptr::null_mut(),
xserv_cuda::current_stream_raw(),
);
}
@@ -383,17 +479,20 @@ pub fn paged_decode_attention(
max_blocks_per_seq: usize,
) -> Tensor {
assert_eq!(q.ndim(), 4);
assert_eq!(q.shape()[2], 1, "paged_decode_attention requires q_len == 1");
assert_eq!(
q.shape()[2],
1,
"paged_decode_attention requires q_len == 1"
);
assert_eq!(q.dtype(), DType::BF16);
assert!(num_q_heads % num_kv_heads == 0, "GQA: num_q_heads must be divisible by num_kv_heads");
assert!(
num_q_heads % num_kv_heads == 0,
"GQA: num_q_heads must be divisible by num_kv_heads"
);
assert!(head_dim <= 128);
let scale = 1.0 / (head_dim as f32).sqrt();
let output = Tensor::empty(
&[batch, num_q_heads, 1, head_dim],
DType::BF16,
q.device(),
);
let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
unsafe {
launch_paged_decode_attention_bf16(
@@ -409,7 +508,7 @@ pub fn paged_decode_attention(
head_dim as i32,
max_blocks_per_seq as i32,
scale,
std::ptr::null_mut(),
xserv_cuda::current_stream_raw(),
);
}
@@ -442,11 +541,7 @@ pub fn paged_decode_attention_sinks(
assert!(head_dim <= 128);
let scale = 1.0 / (head_dim as f32).sqrt();
let output = Tensor::empty(
&[batch, num_q_heads, 1, head_dim],
DType::BF16,
q.device(),
);
let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
unsafe {
launch_paged_decode_attention_sinks_bf16(
@@ -464,7 +559,7 @@ pub fn paged_decode_attention_sinks(
max_blocks_per_seq as i32,
scale,
window_size as i32,
std::ptr::null_mut(),
xserv_cuda::current_stream_raw(),
);
}

View File

@@ -5,104 +5,302 @@ use std::ffi::c_void;
// Re-declare the extern functions we need (same as in the individual modules)
unsafe extern "C" {
fn launch_rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
fn launch_add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void,
normed_out: *mut c_void, sum_out: *mut c_void,
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
fn launch_silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void);
fn launch_reshape_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
fn launch_merge_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
fn launch_transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
fn launch_transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
fn launch_rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
positions: *const c_void, num_tokens: i32, num_heads: i32,
head_dim: i32, stream: *mut c_void);
fn launch_gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void, y_fp32_buf: *mut c_void,
k: i32, n: i32, stream: *mut c_void);
fn launch_rmsnorm_bf16(
x: *const c_void,
gamma: *const c_void,
out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
);
fn launch_add_rmsnorm_bf16(
x: *const c_void,
residual: *const c_void,
gamma: *const c_void,
normed_out: *mut c_void,
sum_out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
);
fn launch_silu_mul_bf16(
gate: *const c_void,
up: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
);
fn launch_add_bf16(
a: *const c_void,
b: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
);
fn launch_embedding_bf16(
table: *const c_void,
token_ids: *const c_void,
out: *mut c_void,
num_tokens: i32,
hidden_size: i32,
vocab_size: i32,
stream: *mut c_void,
);
fn launch_reshape_heads_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_merge_heads_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_transpose_hsd_to_shd_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_transpose_shd_to_hsd_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_rope_bf16(
x: *mut c_void,
cos_cache: *const c_void,
sin_cache: *const c_void,
positions: *const c_void,
num_tokens: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_gemv_bf16(
x: *const c_void,
w: *const c_void,
y_bf16: *mut c_void,
y_fp32_buf: *mut c_void,
k: i32,
n: i32,
stream: *mut c_void,
);
fn launch_decode_attention_bf16(
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
kv_len: i32, head_dim: i32,
scale: f32, causal: i32, stream: *mut c_void,
q: *const c_void,
k: *const c_void,
v: *const c_void,
o: *mut c_void,
batch: i32,
num_q_heads: i32,
num_kv_heads: i32,
kv_len: i32,
head_dim: i32,
scale: f32,
causal: i32,
stream: *mut c_void,
);
}
/// Raw rmsnorm dispatch: writes to pre-allocated `out`.
pub unsafe fn rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void) {
pub unsafe fn rmsnorm_bf16(
x: *const c_void,
gamma: *const c_void,
out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
) {
launch_rmsnorm_bf16(x, gamma, out, rows, hidden_size, eps, stream);
}
/// Raw add_rmsnorm dispatch.
pub unsafe fn add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void,
normed_out: *mut c_void, sum_out: *mut c_void,
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void) {
launch_add_rmsnorm_bf16(x, residual, gamma, normed_out, sum_out, rows, hidden_size, eps, stream);
pub unsafe fn add_rmsnorm_bf16(
x: *const c_void,
residual: *const c_void,
gamma: *const c_void,
normed_out: *mut c_void,
sum_out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
) {
launch_add_rmsnorm_bf16(
x,
residual,
gamma,
normed_out,
sum_out,
rows,
hidden_size,
eps,
stream,
);
}
/// Raw silu_mul dispatch.
pub unsafe fn silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void) {
pub unsafe fn silu_mul_bf16(
gate: *const c_void,
up: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
) {
launch_silu_mul_bf16(gate, up, out, n, stream);
}
/// Raw add dispatch.
pub unsafe fn add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void) {
pub unsafe fn add_bf16(
a: *const c_void,
b: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
) {
launch_add_bf16(a, b, out, n, stream);
}
/// Raw embedding dispatch.
pub unsafe fn embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void) {
launch_embedding_bf16(table, token_ids, out, num_tokens, hidden_size, vocab_size, stream);
pub unsafe fn embedding_bf16(
table: *const c_void,
token_ids: *const c_void,
out: *mut c_void,
num_tokens: i32,
hidden_size: i32,
vocab_size: i32,
stream: *mut c_void,
) {
launch_embedding_bf16(
table,
token_ids,
out,
num_tokens,
hidden_size,
vocab_size,
stream,
);
}
/// Raw reshape_heads dispatch.
pub unsafe fn reshape_heads_bf16(inp: *const c_void, out: *mut c_void,
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
pub unsafe fn reshape_heads_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
) {
launch_reshape_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream);
}
/// Raw merge_heads dispatch.
pub unsafe fn merge_heads_bf16(inp: *const c_void, out: *mut c_void,
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
pub unsafe fn merge_heads_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
) {
launch_merge_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream);
}
/// Raw transpose HSD->SHD dispatch.
pub unsafe fn transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void,
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
pub unsafe fn transpose_hsd_to_shd_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
) {
launch_transpose_hsd_to_shd_bf16(inp, out, seq_len, num_heads, head_dim, stream);
}
/// Raw transpose SHD->HSD dispatch.
pub unsafe fn transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void,
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
pub unsafe fn transpose_shd_to_hsd_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
) {
launch_transpose_shd_to_hsd_bf16(inp, out, seq_len, num_heads, head_dim, stream);
}
/// Raw RoPE dispatch (in-place).
pub unsafe fn rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
positions: *const c_void, num_tokens: i32, num_heads: i32,
head_dim: i32, stream: *mut c_void) {
launch_rope_bf16(x, cos_cache, sin_cache, positions, num_tokens, num_heads, head_dim, stream);
pub unsafe fn rope_bf16(
x: *mut c_void,
cos_cache: *const c_void,
sin_cache: *const c_void,
positions: *const c_void,
num_tokens: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
) {
launch_rope_bf16(
x, cos_cache, sin_cache, positions, num_tokens, num_heads, head_dim, stream,
);
}
/// Raw GEMV dispatch (BF16, M=1). Caller must provide fp32 accumulator buffer.
pub unsafe fn gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void,
y_fp32_buf: *mut c_void, k: i32, n: i32, stream: *mut c_void) {
pub unsafe fn gemv_bf16(
x: *const c_void,
w: *const c_void,
y_bf16: *mut c_void,
y_fp32_buf: *mut c_void,
k: i32,
n: i32,
stream: *mut c_void,
) {
launch_gemv_bf16(x, w, y_bf16, y_fp32_buf, k, n, stream);
}
/// Raw decode attention dispatch.
pub unsafe fn decode_attention_bf16(q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
kv_len: i32, head_dim: i32,
scale: f32, stream: *mut c_void) {
launch_decode_attention_bf16(q, k, v, o, batch, num_q_heads, num_kv_heads, kv_len, head_dim, scale, 1, stream);
pub unsafe fn decode_attention_bf16(
q: *const c_void,
k: *const c_void,
v: *const c_void,
o: *mut c_void,
batch: i32,
num_q_heads: i32,
num_kv_heads: i32,
kv_len: i32,
head_dim: i32,
scale: f32,
stream: *mut c_void,
) {
launch_decode_attention_bf16(
q,
k,
v,
o,
batch,
num_q_heads,
num_kv_heads,
kv_len,
head_dim,
scale,
1,
stream,
);
}
// cuBLAS FFI

View File

@@ -1,12 +1,25 @@
use std::ffi::c_void;
use xserv_cuda::GpuBuffer;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_embedding_f32(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void);
fn launch_embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void);
fn launch_embedding_f32(
table: *const c_void,
token_ids: *const c_void,
out: *mut c_void,
num_tokens: i32,
hidden_size: i32,
vocab_size: i32,
stream: *mut c_void,
);
fn launch_embedding_bf16(
table: *const c_void,
token_ids: *const c_void,
out: *mut c_void,
num_tokens: i32,
hidden_size: i32,
vocab_size: i32,
stream: *mut c_void,
);
}
/// Embedding lookup: table[token_ids[i]] for each i.
@@ -19,8 +32,14 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
let hidden_size = table.shape()[1];
let num_tokens = token_ids.len();
let vocab_size = table.shape()[0];
assert!(num_tokens <= i32::MAX as usize, "too many tokens for i32 kernel param");
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
assert!(
num_tokens <= i32::MAX as usize,
"too many tokens for i32 kernel param"
);
assert!(
hidden_size <= i32::MAX as usize,
"hidden_size too large for i32 kernel param"
);
// Upload token_ids to GPU
let ids_bytes = unsafe {
@@ -29,26 +48,51 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
num_tokens * std::mem::size_of::<u32>(),
)
};
let mut ids_gpu = xserv_cuda::allocator::cached_alloc(ids_bytes.len()).expect("alloc token_ids");
let mut ids_gpu =
xserv_cuda::allocator::cached_alloc(ids_bytes.len()).expect("alloc token_ids");
ids_gpu.copy_from_host(ids_bytes).unwrap();
for &tid in token_ids {
assert!((tid as usize) < vocab_size, "token_id {tid} out of bounds (vocab_size={vocab_size})");
assert!(
(tid as usize) < vocab_size,
"token_id {tid} out of bounds (vocab_size={vocab_size})"
);
}
embedding_device_ids(table, ids_gpu.as_ptr() as *const c_void, num_tokens)
}
/// Embedding lookup with token ids already on the GPU (u32, [num_tokens]).
/// Used by the CUDA-graph decode path, where ids live in a persistent device
/// buffer updated outside the captured region (no bounds check possible here).
pub fn embedding_device_ids(table: &Tensor, ids_gpu: *const c_void, num_tokens: usize) -> Tensor {
assert_eq!(table.ndim(), 2);
assert!(table.is_contiguous());
assert!(matches!(table.device(), Device::Cuda(_)));
let hidden_size = table.shape()[1];
let vocab_size = table.shape()[0];
let out = Tensor::empty(&[num_tokens, hidden_size], table.dtype(), table.device());
unsafe {
match table.dtype() {
DType::F32 => launch_embedding_f32(
table.data_ptr() as _, ids_gpu.as_ptr() as _,
table.data_ptr() as _,
ids_gpu,
out.data_ptr() as *mut c_void,
num_tokens as i32, hidden_size as i32, vocab_size as i32, std::ptr::null_mut(),
num_tokens as i32,
hidden_size as i32,
vocab_size as i32,
xserv_cuda::current_stream_raw(),
),
DType::BF16 => launch_embedding_bf16(
table.data_ptr() as _, ids_gpu.as_ptr() as _,
table.data_ptr() as _,
ids_gpu,
out.data_ptr() as *mut c_void,
num_tokens as i32, hidden_size as i32, vocab_size as i32, std::ptr::null_mut(),
num_tokens as i32,
hidden_size as i32,
vocab_size as i32,
xserv_cuda::current_stream_raw(),
),
_ => panic!("unsupported dtype for embedding"),
}

View File

@@ -1,14 +1,22 @@
use std::cell::RefCell;
use std::ffi::c_void;
use xserv_cuda::error::{self, Result};
use xserv_cuda::GpuBuffer;
use xserv_cuda::error::{self, Result};
use xserv_tensor::{DType, Device, Tensor};
const CUBLAS_WORKSPACE_BYTES: usize = 32 * 1024 * 1024;
// GEMV: single-kernel, no FP32 temp buffer needed
unsafe extern "C" {
fn launch_gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void, y_fp32_buf: *mut c_void, k: i32, n: i32, stream: *mut c_void);
fn launch_gemv_bf16(
x: *const c_void,
w: *const c_void,
y_bf16: *mut c_void,
y_fp32_buf: *mut c_void,
k: i32,
n: i32,
stream: *mut c_void,
);
}
#[derive(Debug, Clone, Copy)]
@@ -20,10 +28,42 @@ pub enum GemmBackend {
// --- FFI: custom CUDA kernels ---
unsafe extern "C" {
fn launch_gemm_naive_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
fn launch_gemm_naive_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
fn launch_gemm_tiled_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
fn launch_gemm_tiled_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
fn launch_gemm_naive_f32(
a: *const c_void,
b: *const c_void,
c: *mut c_void,
m: i32,
n: i32,
k: i32,
stream: *mut c_void,
);
fn launch_gemm_naive_bf16(
a: *const c_void,
b: *const c_void,
c: *mut c_void,
m: i32,
n: i32,
k: i32,
stream: *mut c_void,
);
fn launch_gemm_tiled_f32(
a: *const c_void,
b: *const c_void,
c: *mut c_void,
m: i32,
n: i32,
k: i32,
stream: *mut c_void,
);
fn launch_gemm_tiled_bf16(
a: *const c_void,
b: *const c_void,
c: *mut c_void,
m: i32,
n: i32,
k: i32,
stream: *mut c_void,
);
}
// --- FFI: cuBLAS ---
@@ -46,25 +86,46 @@ unsafe extern "C" {
fn cublasSetWorkspace_v2(handle: CublasHandle, workspace: *mut c_void, size: usize) -> i32;
fn cublasGemmEx(
handle: CublasHandle,
transa: i32, transb: i32,
m: i32, n: i32, k: i32,
transa: i32,
transb: i32,
m: i32,
n: i32,
k: i32,
alpha: *const c_void,
a: *const c_void, a_type: i32, lda: i32,
b: *const c_void, b_type: i32, ldb: i32,
a: *const c_void,
a_type: i32,
lda: i32,
b: *const c_void,
b_type: i32,
ldb: i32,
beta: *const c_void,
c: *mut c_void, c_type: i32, ldc: i32,
c: *mut c_void,
c_type: i32,
ldc: i32,
compute_type: i32,
algo: i32,
) -> i32;
fn cublasGemmStridedBatchedEx(
handle: CublasHandle,
transa: i32, transb: i32,
m: i32, n: i32, k: i32,
transa: i32,
transb: i32,
m: i32,
n: i32,
k: i32,
alpha: *const c_void,
a: *const c_void, a_type: i32, lda: i32, stride_a: i64,
b: *const c_void, b_type: i32, ldb: i32, stride_b: i64,
a: *const c_void,
a_type: i32,
lda: i32,
stride_a: i64,
b: *const c_void,
b_type: i32,
ldb: i32,
stride_b: i64,
beta: *const c_void,
c: *mut c_void, c_type: i32, ldc: i32, stride_c: i64,
c: *mut c_void,
c_type: i32,
ldc: i32,
stride_c: i64,
batch_count: i32,
compute_type: i32,
algo: i32,
@@ -89,9 +150,16 @@ impl CublasContext {
// set, so we keep the GpuBuffer in this struct.
let mut workspace = GpuBuffer::alloc(CUBLAS_WORKSPACE_BYTES)?;
error::check(unsafe {
cublasSetWorkspace_v2(handle, workspace.as_mut_ptr() as *mut c_void, CUBLAS_WORKSPACE_BYTES)
cublasSetWorkspace_v2(
handle,
workspace.as_mut_ptr() as *mut c_void,
CUBLAS_WORKSPACE_BYTES,
)
})?;
Ok(Self { handle, _workspace: Some(workspace) })
Ok(Self {
handle,
_workspace: Some(workspace),
})
}
}
@@ -123,9 +191,7 @@ where
/// Get the thread-local cuBLAS handle for use with dispatch module.
pub fn cublas_handle() -> CublasHandle {
CUBLAS_CTX.with(|cell| {
cell.borrow().handle
})
CUBLAS_CTX.with(|cell| cell.borrow().handle)
}
/// Matrix multiplication: C = A @ B
@@ -136,8 +202,14 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
assert_eq!(b.ndim(), 2);
assert_eq!(a.shape()[1], b.shape()[0], "inner dimension mismatch");
assert_eq!(a.dtype(), b.dtype(), "dtype mismatch");
assert!(a.is_contiguous() && b.is_contiguous(), "matmul requires contiguous tensors");
assert!(matches!(a.device(), Device::Cuda(_)), "matmul requires GPU tensors");
assert!(
a.is_contiguous() && b.is_contiguous(),
"matmul requires contiguous tensors"
);
assert!(
matches!(a.device(), Device::Cuda(_)),
"matmul requires GPU tensors"
);
let m = a.shape()[0];
let k = a.shape()[1];
@@ -151,35 +223,66 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
let a_ptr = a.data_ptr() as *const c_void;
let b_ptr = b.data_ptr() as *const c_void;
let c_ptr = c.data_ptr() as *mut c_void;
let null_stream = std::ptr::null_mut();
let null_stream = xserv_cuda::current_stream_raw();
match backend {
GemmBackend::Naive => {
unsafe {
match dtype {
DType::F32 => launch_gemm_naive_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
DType::BF16 => launch_gemm_naive_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
_ => panic!("unsupported dtype for naive GEMM"),
}
GemmBackend::Naive => unsafe {
match dtype {
DType::F32 => launch_gemm_naive_f32(
a_ptr,
b_ptr,
c_ptr,
m as i32,
n as i32,
k as i32,
null_stream,
),
DType::BF16 => launch_gemm_naive_bf16(
a_ptr,
b_ptr,
c_ptr,
m as i32,
n as i32,
k as i32,
null_stream,
),
_ => panic!("unsupported dtype for naive GEMM"),
}
}
GemmBackend::Tiled => {
unsafe {
match dtype {
DType::F32 => launch_gemm_tiled_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
DType::BF16 => launch_gemm_tiled_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
_ => panic!("unsupported dtype for tiled GEMM"),
}
},
GemmBackend::Tiled => unsafe {
match dtype {
DType::F32 => launch_gemm_tiled_f32(
a_ptr,
b_ptr,
c_ptr,
m as i32,
n as i32,
k as i32,
null_stream,
),
DType::BF16 => launch_gemm_tiled_bf16(
a_ptr,
b_ptr,
c_ptr,
m as i32,
n as i32,
k as i32,
null_stream,
),
_ => panic!("unsupported dtype for tiled GEMM"),
}
}
},
GemmBackend::CuBlas => {
if m == 1 && dtype == DType::BF16 && n >= 256 {
let mut fp32_buf = xserv_cuda::allocator::cached_alloc(n * 4).unwrap();
unsafe {
launch_gemv_bf16(
a_ptr, b_ptr, c_ptr,
a_ptr,
b_ptr,
c_ptr,
fp32_buf.as_mut_ptr() as *mut c_void,
k as i32, n as i32,
k as i32,
n as i32,
null_stream,
);
}
@@ -197,16 +300,26 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
cublasSetStream_v2(handle, null_stream);
error::check(cublasGemmEx(
handle,
CUBLAS_OP_N, CUBLAS_OP_N,
n as i32, m as i32, k as i32,
CUBLAS_OP_N,
CUBLAS_OP_N,
n as i32,
m as i32,
k as i32,
&alpha as *const f32 as *const c_void,
b_ptr, b_type, n as i32,
a_ptr, a_type, k as i32,
b_ptr,
b_type,
n as i32,
a_ptr,
a_type,
k as i32,
&beta as *const f32 as *const c_void,
c_ptr, c_type, n as i32,
c_ptr,
c_type,
n as i32,
CUBLAS_COMPUTE_32F,
-1,
)).expect("cuBLAS GEMM failed");
))
.expect("cuBLAS GEMM failed");
});
}
}
@@ -260,21 +373,34 @@ pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor {
let stride_c = (m * n) as i64;
with_cublas(|handle| unsafe {
cublasSetStream_v2(handle, std::ptr::null_mut());
cublasSetStream_v2(handle, xserv_cuda::current_stream_raw());
// Row-major trick: C = A @ B ⟺ C^T = B^T @ A^T (col-major)
error::check(cublasGemmStridedBatchedEx(
handle,
CUBLAS_OP_N, CUBLAS_OP_N,
n as i32, m as i32, k as i32,
CUBLAS_OP_N,
CUBLAS_OP_N,
n as i32,
m as i32,
k as i32,
&alpha as *const f32 as *const c_void,
b.data_ptr() as _, b_type, n as i32, stride_b,
a.data_ptr() as _, a_type, k as i32, stride_a,
b.data_ptr() as _,
b_type,
n as i32,
stride_b,
a.data_ptr() as _,
a_type,
k as i32,
stride_a,
&beta as *const f32 as *const c_void,
c.data_ptr() as *mut c_void, c_type, n as i32, stride_c,
c.data_ptr() as *mut c_void,
c_type,
n as i32,
stride_c,
batch as i32,
CUBLAS_COMPUTE_32F,
-1,
)).expect("cuBLAS batched GEMM failed");
))
.expect("cuBLAS batched GEMM failed");
});
c
}

View File

@@ -2,10 +2,26 @@ use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_layernorm_f32(x: *const c_void, gamma: *const c_void, beta: *const c_void,
out: *mut c_void, rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
fn launch_layernorm_bf16(x: *const c_void, gamma: *const c_void, beta: *const c_void,
out: *mut c_void, rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
fn launch_layernorm_f32(
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
);
fn launch_layernorm_bf16(
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
);
}
pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor {
@@ -17,21 +33,37 @@ pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor
assert_eq!(beta.shape(), &[hidden_size]);
let rows = x.numel() / hidden_size;
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
assert!(
rows <= i32::MAX as usize,
"too many rows for i32 kernel param"
);
assert!(
hidden_size <= i32::MAX as usize,
"hidden_size too large for i32 kernel param"
);
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
unsafe {
match x.dtype() {
DType::F32 => launch_layernorm_f32(
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
x.data_ptr() as _,
gamma.data_ptr() as _,
beta.data_ptr() as _,
out.data_ptr() as *mut c_void,
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
rows as i32,
hidden_size as i32,
eps,
xserv_cuda::current_stream_raw(),
),
DType::BF16 => launch_layernorm_bf16(
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
x.data_ptr() as _,
gamma.data_ptr() as _,
beta.data_ptr() as _,
out.data_ptr() as *mut c_void,
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
rows as i32,
hidden_size as i32,
eps,
xserv_cuda::current_stream_raw(),
),
_ => panic!("unsupported dtype for layernorm"),
}

View File

@@ -12,16 +12,22 @@ pub mod rope;
pub mod softmax;
pub mod transpose;
pub use activation::{add, gelu, gpt_oss_glu, mul, scale, silu, silu_mul};
pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu_mul};
pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
pub use attention::{attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention, paged_decode_attention_sinks, reshape_and_cache_bf16, reshape_and_cache_batched_bf16};
pub use embedding::embedding;
pub use gemm::{batched_matmul, matmul, GemmBackend};
pub use attention::{
attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention,
paged_decode_attention_sinks, reshape_and_cache_batched_bf16, reshape_and_cache_bf16,
};
pub use embedding::{embedding, embedding_device_ids};
pub use gemm::{GemmBackend, batched_matmul, matmul};
pub use layernorm::layernorm;
pub use rmsnorm::{add_rmsnorm, rmsnorm};
pub use rope::{rope_inplace, RopeCache};
pub use rope::{RopeCache, rope_inplace, rope_inplace_device_pos};
pub use softmax::softmax;
pub use transpose::{
merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu,
transpose_for_rope_gpu, transpose_from_rope_gpu,
};
/// Register GPU kernels with the tensor crate. Call once at startup.
pub fn init() {

View File

@@ -1,43 +1,113 @@
use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor};
use xserv_tensor::{DType, Tensor};
use crate::gemm::{cublas_handle, CublasHandle};
use crate::gemm::{CublasHandle, cublas_handle};
unsafe extern "C" {
fn launch_moe_topk_softmax_bf16(
router_logits: *const c_void,
topk_ids: *mut c_void, topk_weights: *mut c_void,
num_tokens: i32, num_experts: i32, top_k: i32,
topk_ids: *mut c_void,
topk_weights: *mut c_void,
num_tokens: i32,
num_experts: i32,
top_k: i32,
stream: *mut c_void,
);
fn launch_moe_replicate_bf16(
x: *const c_void, x_rep: *mut c_void,
num_tokens: i32, hidden: i32, local_experts: i32,
x: *const c_void,
x_rep: *mut c_void,
num_tokens: i32,
hidden: i32,
local_experts: i32,
stream: *mut c_void,
);
fn launch_moe_bias_add_3d_bf16(
x: *mut c_void, bias: *const c_void,
batch: i32, num_tokens: i32, dim: i32,
x: *mut c_void,
bias: *const c_void,
batch: i32,
num_tokens: i32,
dim: i32,
stream: *mut c_void,
);
fn launch_moe_weighted_sum_bf16(
expert_out: *const c_void,
topk_ids: *const c_void, topk_weights: *const c_void,
topk_ids: *const c_void,
topk_weights: *const c_void,
out: *mut c_void,
num_tokens: i32, hidden: i32, top_k: i32,
expert_start: i32, local_experts: i32,
num_tokens: i32,
hidden: i32,
top_k: i32,
expert_start: i32,
local_experts: i32,
stream: *mut c_void,
);
fn launch_moe_sparse_gemv_fp8_bf16(
x: *const c_void,
w: *const c_void,
w_scales: *const c_void,
bias: *const c_void,
topk_ids: *const c_void,
y: *mut c_void,
num_tokens: i32,
n: i32,
k: i32,
top_k: i32,
expert_start: i32,
local_experts: i32,
x_per_slot: i32,
stream: *mut c_void,
);
fn launch_moe_sparse_gemv_mxfp4_bf16(
x: *const c_void,
w_packed: *const c_void,
w_scales: *const c_void,
bias: *const c_void,
topk_ids: *const c_void,
y: *mut c_void,
num_tokens: i32,
n: i32,
k: i32,
top_k: i32,
expert_start: i32,
local_experts: i32,
x_per_slot: i32,
stream: *mut c_void,
);
fn launch_moe_weighted_sum_sparse_bf16(
down: *const c_void,
topk_ids: *const c_void,
topk_weights: *const c_void,
out: *mut c_void,
num_tokens: i32,
hidden: i32,
top_k: i32,
expert_start: i32,
local_experts: i32,
stream: *mut c_void,
);
fn cublasGemmStridedBatchedEx(
handle: CublasHandle,
transa: i32, transb: i32,
m: i32, n: i32, k: i32,
transa: i32,
transb: i32,
m: i32,
n: i32,
k: i32,
alpha: *const c_void,
a: *const c_void, a_type: i32, lda: i32, stride_a: i64,
b: *const c_void, b_type: i32, ldb: i32, stride_b: i64,
a: *const c_void,
a_type: i32,
lda: i32,
stride_a: i64,
b: *const c_void,
b_type: i32,
ldb: i32,
stride_b: i64,
beta: *const c_void,
c: *mut c_void, c_type: i32, ldc: i32, stride_c: i64,
c: *mut c_void,
c_type: i32,
ldc: i32,
stride_c: i64,
batch_count: i32,
compute_type: i32,
algo: i32,
@@ -65,6 +135,9 @@ pub fn moe_topk_softmax(
let num_tokens = router_logits.shape()[0];
assert_eq!(router_logits.shape()[1], num_experts);
// NOTE: topk_ids actually holds i32 expert indices; DType has no I32, so
// this is a raw 4-byte buffer mislabeled F32. Never read it as floats —
// all consumers (weighted-sum / sparse GEMV kernels) cast to int*.
let topk_ids = Tensor::empty(&[num_tokens, top_k], DType::F32, router_logits.device());
let topk_weights = Tensor::empty(&[num_tokens, top_k], DType::F32, router_logits.device());
@@ -73,8 +146,10 @@ pub fn moe_topk_softmax(
router_logits.data_ptr() as *const c_void,
topk_ids.data_ptr() as *mut c_void,
topk_weights.data_ptr() as *mut c_void,
num_tokens as i32, num_experts as i32, top_k as i32,
std::ptr::null_mut(),
num_tokens as i32,
num_experts as i32,
top_k as i32,
xserv_cuda::current_stream_raw(),
);
}
@@ -88,14 +163,20 @@ pub fn moe_replicate(x: &Tensor, local_experts: usize) -> Tensor {
assert!(x.is_contiguous());
let num_tokens = x.shape()[0];
let hidden = x.shape()[1];
let out = Tensor::empty(&[local_experts, num_tokens, hidden], DType::BF16, x.device());
let out = Tensor::empty(
&[local_experts, num_tokens, hidden],
DType::BF16,
x.device(),
);
unsafe {
launch_moe_replicate_bf16(
x.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
num_tokens as i32, hidden as i32, local_experts as i32,
std::ptr::null_mut(),
num_tokens as i32,
hidden as i32,
local_experts as i32,
xserv_cuda::current_stream_raw(),
);
}
@@ -117,8 +198,10 @@ pub fn moe_bias_add_3d(x: &Tensor, bias: &Tensor) {
launch_moe_bias_add_3d_bf16(
x.data_ptr() as *mut c_void,
bias.data_ptr() as *const c_void,
batch as i32, num_tokens as i32, dim as i32,
std::ptr::null_mut(),
batch as i32,
num_tokens as i32,
dim as i32,
xserv_cuda::current_stream_raw(),
);
}
}
@@ -149,15 +232,171 @@ pub fn moe_weighted_sum(
topk_ids.data_ptr() as *const c_void,
topk_weights.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
num_tokens as i32, hidden as i32, top_k as i32,
expert_start as i32, local_experts as i32,
std::ptr::null_mut(),
num_tokens as i32,
hidden as i32,
top_k as i32,
expert_start as i32,
local_experts as i32,
xserv_cuda::current_stream_raw(),
);
}
out
}
/// Sparse MoE GEMV (FP8 W8A16): compute only the routed experts.
///
/// x: [num_tokens, K] BF16 (x_per_slot=false, gate_up) or
/// [num_tokens * top_k, K] BF16 (x_per_slot=true, down)
/// w_fp8_t: [local_experts, N, K] FP8E4M3 (transposed weight layout)
/// w_scales: [local_experts] F32 per-expert scalar scales
/// bias: [local_experts, N] BF16 (fused into the epilogue)
/// topk_ids: [num_tokens, top_k] i32 global expert ids (GPU)
///
/// Returns y [num_tokens, top_k, N] BF16. Slots routed to experts NOT
/// owned by this rank are left UNWRITTEN (uninitialized memory) — the
/// consumer must skip them (see moe_weighted_sum_sparse).
#[allow(clippy::too_many_arguments)]
pub fn moe_sparse_gemv_fp8(
x: &Tensor,
w_fp8_t: &Tensor,
w_scales: &Tensor,
bias: &Tensor,
topk_ids: &Tensor,
num_tokens: usize,
top_k: usize,
expert_start: usize,
local_experts: usize,
x_per_slot: bool,
) -> Tensor {
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous());
assert_eq!(w_fp8_t.dtype(), DType::FP8E4M3);
let n = w_fp8_t.shape()[1];
let k = w_fp8_t.shape()[2];
// The kernel reads weights as uint4 (16 FP8 values per lane) and would
// silently skip a K%16 tail.
assert_eq!(k % 16, 0, "sparse FP8 GEMV requires K % 16 == 0, got {k}");
assert_eq!(x.shape()[x.ndim() - 1], k);
assert_eq!(
x.shape()[0],
if x_per_slot {
num_tokens * top_k
} else {
num_tokens
}
);
let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device());
unsafe {
launch_moe_sparse_gemv_fp8_bf16(
x.data_ptr() as *const c_void,
w_fp8_t.data_ptr() as *const c_void,
w_scales.data_ptr() as *const c_void,
bias.data_ptr() as *const c_void,
topk_ids.data_ptr() as *const c_void,
y.data_ptr() as *mut c_void,
num_tokens as i32,
n as i32,
k as i32,
top_k as i32,
expert_start as i32,
local_experts as i32,
x_per_slot as i32,
xserv_cuda::current_stream_raw(),
);
}
y
}
/// Sparse MoE GEMV (MXFP4 W4A16): same contract as moe_sparse_gemv_fp8,
/// with packed 4-bit weights [E, N, K/2] + UE8M0 block scales [E, N, K/32].
#[allow(clippy::too_many_arguments)]
pub fn moe_sparse_gemv_mxfp4(
x: &Tensor,
w_packed: &Tensor,
w_scales: &Tensor,
bias: &Tensor,
topk_ids: &Tensor,
num_tokens: usize,
top_k: usize,
n: usize,
k: usize,
expert_start: usize,
local_experts: usize,
x_per_slot: bool,
) -> Tensor {
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous());
// 32-element MXFP4 blocks, read as uint4 (32 nibbles) per lane.
assert_eq!(k % 32, 0, "sparse MXFP4 GEMV requires K % 32 == 0, got {k}");
assert_eq!(x.shape()[x.ndim() - 1], k);
assert_eq!(
x.shape()[0],
if x_per_slot {
num_tokens * top_k
} else {
num_tokens
}
);
let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device());
unsafe {
launch_moe_sparse_gemv_mxfp4_bf16(
x.data_ptr() as *const c_void,
w_packed.data_ptr() as *const c_void,
w_scales.data_ptr() as *const c_void,
bias.data_ptr() as *const c_void,
topk_ids.data_ptr() as *const c_void,
y.data_ptr() as *mut c_void,
num_tokens as i32,
n as i32,
k as i32,
top_k as i32,
expert_start as i32,
local_experts as i32,
x_per_slot as i32,
xserv_cuda::current_stream_raw(),
);
}
y
}
/// Weighted sum over the slot axis of the sparse GEMV output.
///
/// down: [num_tokens, top_k, hidden] BF16 (non-local slots uninitialized
/// and skipped, never multiplied by zero — NaN * 0 = NaN).
pub fn moe_weighted_sum_sparse(
down: &Tensor,
topk_ids: &Tensor,
topk_weights: &Tensor,
expert_start: usize,
local_experts: usize,
) -> Tensor {
assert_eq!(down.ndim(), 3);
assert_eq!(down.dtype(), DType::BF16);
let num_tokens = down.shape()[0];
let top_k = down.shape()[1];
let hidden = down.shape()[2];
let out = Tensor::empty(&[num_tokens, hidden], DType::BF16, down.device());
unsafe {
launch_moe_weighted_sum_sparse_bf16(
down.data_ptr() as *const c_void,
topk_ids.data_ptr() as *const c_void,
topk_weights.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
num_tokens as i32,
hidden as i32,
top_k as i32,
expert_start as i32,
local_experts as i32,
xserv_cuda::current_stream_raw(),
);
}
out
}
/// Strided batched GEMM for MoE expert forward.
/// C[b] = A[b] @ B[b] for b in 0..batch
///
@@ -202,16 +441,28 @@ pub fn batched_gemm_strided(a: &Tensor, b: &Tensor) -> Tensor {
let handle = cublas_handle();
unsafe {
cublasSetStream_v2(handle, std::ptr::null_mut());
cublasSetStream_v2(handle, xserv_cuda::current_stream_raw());
let status = cublasGemmStridedBatchedEx(
handle,
0, 0, // CUBLAS_OP_N, CUBLAS_OP_N
n as i32, m as i32, k as i32,
0,
0, // CUBLAS_OP_N, CUBLAS_OP_N
n as i32,
m as i32,
k as i32,
&alpha as *const f32 as *const c_void,
b.data_ptr() as *const c_void, CUDA_R_16BF, n as i32, stride_b,
a.data_ptr() as *const c_void, CUDA_R_16BF, k as i32, stride_a,
b.data_ptr() as *const c_void,
CUDA_R_16BF,
n as i32,
stride_b,
a.data_ptr() as *const c_void,
CUDA_R_16BF,
k as i32,
stride_a,
&beta as *const f32 as *const c_void,
c.data_ptr() as *mut c_void, CUDA_R_16BF, n as i32, stride_c,
c.data_ptr() as *mut c_void,
CUDA_R_16BF,
n as i32,
stride_c,
batch as i32,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT,

View File

@@ -13,30 +13,46 @@ unsafe extern "C" {
src: *const c_void,
scales: *const c_void,
dst: *mut c_void,
num_experts: i32, rows: i32, cols: i32,
num_experts: i32,
rows: i32,
cols: i32,
stream: *mut c_void,
);
fn launch_quantize_bf16_to_fp8e4m3_rowwise(
src: *const c_void,
dst: *mut c_void,
scales: *mut c_void,
num_rows: i32, cols: i32,
num_rows: i32,
cols: i32,
stream: *mut c_void,
);
fn launch_rowwise_scale_moe_bf16(
data: *mut c_void,
a_scales: *const c_void,
b_scales: *const c_void,
num_rows: i32, cols: i32, tokens: i32,
num_rows: i32,
cols: i32,
tokens: i32,
stream: *mut c_void,
);
fn launch_batched_gemv_mxfp4_bf16(
x: *const c_void, w_packed: *const c_void, w_scales: *const c_void, y: *mut c_void,
e: i32, n: i32, k: i32, stream: *mut c_void,
x: *const c_void,
w_packed: *const c_void,
w_scales: *const c_void,
y: *mut c_void,
e: i32,
n: i32,
k: i32,
stream: *mut c_void,
);
fn launch_dequant_mxfp4_to_bf16_t(
w_packed: *const c_void, w_scales: *const c_void, out: *mut c_void,
e: i32, n: i32, k: i32, stream: *mut c_void,
w_packed: *const c_void,
w_scales: *const c_void,
out: *mut c_void,
e: i32,
n: i32,
k: i32,
stream: *mut c_void,
);
}
@@ -66,34 +82,68 @@ struct CublasLtMatmulHeuristicResult {
unsafe extern "C" {
fn cublasLtCreate(handle: *mut CublasLtHandle) -> i32;
fn cublasLtDestroy(handle: CublasLtHandle) -> i32;
fn cublasLtMatmulDescCreate(desc: *mut CublasLtMatmulDesc, compute_type: i32, scale_type: i32) -> i32;
fn cublasLtMatmulDescCreate(
desc: *mut CublasLtMatmulDesc,
compute_type: i32,
scale_type: i32,
) -> i32;
fn cublasLtMatmulDescDestroy(desc: CublasLtMatmulDesc) -> i32;
fn cublasLtMatmulDescSetAttribute(desc: CublasLtMatmulDesc, attr: i32, buf: *const c_void, size: usize) -> i32;
fn cublasLtMatrixLayoutCreate(layout: *mut CublasLtMatrixLayout, dtype: i32, rows: u64, cols: u64, ld: i64) -> i32;
fn cublasLtMatmulDescSetAttribute(
desc: CublasLtMatmulDesc,
attr: i32,
buf: *const c_void,
size: usize,
) -> i32;
fn cublasLtMatrixLayoutCreate(
layout: *mut CublasLtMatrixLayout,
dtype: i32,
rows: u64,
cols: u64,
ld: i64,
) -> i32;
fn cublasLtMatrixLayoutDestroy(layout: CublasLtMatrixLayout) -> i32;
fn cublasLtMatrixLayoutSetAttribute(layout: CublasLtMatrixLayout, attr: i32, buf: *const c_void, size: usize) -> i32;
fn cublasLtMatrixLayoutSetAttribute(
layout: CublasLtMatrixLayout,
attr: i32,
buf: *const c_void,
size: usize,
) -> i32;
fn cublasLtMatmulPreferenceCreate(pref: *mut CublasLtMatmulPreference) -> i32;
fn cublasLtMatmulPreferenceDestroy(pref: CublasLtMatmulPreference) -> i32;
fn cublasLtMatmulPreferenceSetAttribute(pref: CublasLtMatmulPreference, attr: i32, buf: *const c_void, size: usize) -> i32;
fn cublasLtMatmulPreferenceSetAttribute(
pref: CublasLtMatmulPreference,
attr: i32,
buf: *const c_void,
size: usize,
) -> i32;
fn cublasLtMatmulAlgoGetHeuristic(
handle: CublasLtHandle, desc: CublasLtMatmulDesc,
a_layout: CublasLtMatrixLayout, b_layout: CublasLtMatrixLayout,
c_layout: CublasLtMatrixLayout, d_layout: CublasLtMatrixLayout,
handle: CublasLtHandle,
desc: CublasLtMatmulDesc,
a_layout: CublasLtMatrixLayout,
b_layout: CublasLtMatrixLayout,
c_layout: CublasLtMatrixLayout,
d_layout: CublasLtMatrixLayout,
pref: CublasLtMatmulPreference,
requested: i32,
results: *mut CublasLtMatmulHeuristicResult,
found: *mut i32,
) -> i32;
fn cublasLtMatmul(
handle: CublasLtHandle, desc: CublasLtMatmulDesc,
handle: CublasLtHandle,
desc: CublasLtMatmulDesc,
alpha: *const c_void,
a: *const c_void, a_layout: CublasLtMatrixLayout,
b: *const c_void, b_layout: CublasLtMatrixLayout,
a: *const c_void,
a_layout: CublasLtMatrixLayout,
b: *const c_void,
b_layout: CublasLtMatrixLayout,
beta: *const c_void,
c: *const c_void, c_layout: CublasLtMatrixLayout,
d: *mut c_void, d_layout: CublasLtMatrixLayout,
c: *const c_void,
c_layout: CublasLtMatrixLayout,
d: *mut c_void,
d_layout: CublasLtMatrixLayout,
algo: *const CublasLtMatmulAlgo,
workspace: *mut c_void, workspace_size: usize,
workspace: *mut c_void,
workspace_size: usize,
stream: *mut c_void,
) -> i32;
}
@@ -107,17 +157,11 @@ const CUDA_R_8F_E4M3: i32 = 28;
// MatmulDesc attributes
const CUBLASLT_MATMUL_DESC_A_SCALE_POINTER: i32 = 17;
const CUBLASLT_MATMUL_DESC_B_SCALE_POINTER: i32 = 18;
const CUBLASLT_MATMUL_DESC_A_SCALE_MODE: i32 = 31;
const CUBLASLT_MATMUL_DESC_B_SCALE_MODE: i32 = 32;
// MatrixLayout attributes
const CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT: i32 = 5;
const CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET: i32 = 6;
// Scale modes
const CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR: i32 = 0;
const CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F: i32 = 3;
// MatmulPreference attributes
const CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES: i32 = 1;
@@ -159,8 +203,15 @@ impl CublasLtContext {
assert_eq!(status, 0, "cublasLtCreate failed: {status}");
let workspace = GpuBuffer::alloc(WORKSPACE_BYTES).expect("alloc cublasLt workspace");
let mut one_buf = GpuBuffer::alloc(4).expect("alloc cublasLt fp8 scale");
one_buf.copy_from_host(&1.0f32.to_le_bytes()).expect("init fp8 scale");
Self { handle, workspace, one_buf, plans: HashMap::new() }
one_buf
.copy_from_host(&1.0f32.to_le_bytes())
.expect("init fp8 scale");
Self {
handle,
workspace,
one_buf,
plans: HashMap::new(),
}
}
/// Get the cached strided-batched plan for (m, n, k, batch), building it on
@@ -216,10 +267,25 @@ unsafe fn build_fp8_plan(
// transA=T (required for FP8 on Blackwell)
let trans_a: i32 = 1;
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a as *const i32 as _, 4);
cublasLtMatmulDescSetAttribute(
desc,
CUBLASLT_MATMUL_DESC_TRANSA,
&trans_a as *const i32 as _,
4,
);
let ptr_sz = std::mem::size_of::<*const c_void>();
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz);
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz);
cublasLtMatmulDescSetAttribute(
desc,
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&one_ptr as *const _ as _,
ptr_sz,
);
cublasLtMatmulDescSetAttribute(
desc,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&one_ptr as *const _ as _,
ptr_sz,
);
// Per-expert strides in ELEMENTS for the strided-batch layout.
let stride_a = (n * k) as i64; // weights [N, K]
@@ -227,10 +293,18 @@ unsafe fn build_fp8_plan(
let stride_c = (m * n) as i64; // output [M, N]
let bc = batch as i32;
let set_batch = |layout: CublasLtMatrixLayout, stride: i64| {
cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&bc as *const i32 as _, 4);
cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&stride as *const i64 as _, 8);
cublasLtMatrixLayoutSetAttribute(
layout,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&bc as *const i32 as _,
4,
);
cublasLtMatrixLayoutSetAttribute(
layout,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&stride as *const i64 as _,
8,
);
};
// "A" layout (weights, transposed): physical (K, N) col-major, ld=K
@@ -252,20 +326,39 @@ unsafe fn build_fp8_plan(
let mut pref: CublasLtMatmulPreference = std::ptr::null_mut();
cublasLtMatmulPreferenceCreate(&mut pref);
let ws_bytes = WORKSPACE_BYTES as u64;
cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &ws_bytes as *const u64 as _, 8);
cublasLtMatmulPreferenceSetAttribute(
pref,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&ws_bytes as *const u64 as _,
8,
);
let mut heuristic = std::mem::zeroed::<CublasLtMatmulHeuristicResult>();
let mut found: i32 = 0;
let status = cublasLtMatmulAlgoGetHeuristic(
handle, desc, a_layout, b_layout, c_layout, d_layout,
pref, 1, &mut heuristic, &mut found,
handle,
desc,
a_layout,
b_layout,
c_layout,
d_layout,
pref,
1,
&mut heuristic,
&mut found,
);
assert!(
status == 0 && found > 0,
"cublasLtMatmulAlgoGetHeuristic failed for batched FP8 GEMM (m={m}, n={n}, k={k}, batch={batch}): status={status}, found={found}"
);
assert!(status == 0 && found > 0,
"cublasLtMatmulAlgoGetHeuristic failed for batched FP8 GEMM (m={m}, n={n}, k={k}, batch={batch}): status={status}, found={found}");
cublasLtMatmulPreferenceDestroy(pref);
Fp8Plan {
desc, a_layout, b_layout, c_layout, d_layout,
desc,
a_layout,
b_layout,
c_layout,
d_layout,
algo: heuristic.algo,
workspace_size: heuristic.workspace_size,
}
@@ -305,8 +398,10 @@ pub fn dequant_fp8_to_bf16(src: &Tensor, scales: &Tensor) -> Tensor {
src.data_ptr() as *const c_void,
scales.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
num_experts as i32, rows as i32, cols as i32,
std::ptr::null_mut(),
num_experts as i32,
rows as i32,
cols as i32,
xserv_cuda::current_stream_raw(),
);
}
@@ -335,8 +430,9 @@ pub fn quantize_bf16_to_fp8_rowwise(src: &Tensor) -> (Tensor, Tensor) {
src.data_ptr() as *const c_void,
fp8_out.data_ptr() as *mut c_void,
scales.data_ptr() as *mut c_void,
num_rows as i32, cols as i32,
std::ptr::null_mut(),
num_rows as i32,
cols as i32,
xserv_cuda::current_stream_raw(),
);
}
@@ -398,23 +494,27 @@ pub fn batched_gemm_fp8(
unsafe {
let status = cublasLtMatmul(
handle, plan.desc,
handle,
plan.desc,
&alpha as *const f32 as _,
b_fp8_t.data_ptr() as *const c_void, // cuBLASLt "A" = weights
plan.a_layout,
a_fp8.data_ptr() as *const c_void, // cuBLASLt "B" = activations
a_fp8.data_ptr() as *const c_void, // cuBLASLt "B" = activations
plan.b_layout,
&beta as *const f32 as _,
c.data_ptr() as *const c_void, // C (unused with beta=0)
c.data_ptr() as *const c_void, // C (unused with beta=0)
plan.c_layout,
c.data_ptr() as *mut c_void, // D = output
c.data_ptr() as *mut c_void, // D = output
plan.d_layout,
&plan.algo,
ws_ptr,
plan.workspace_size,
std::ptr::null_mut(),
xserv_cuda::current_stream_raw(),
);
assert_eq!(
status, 0,
"batched cublasLtMatmul FP8 failed: status={status}"
);
assert_eq!(status, 0, "batched cublasLtMatmul FP8 failed: status={status}");
}
});
@@ -429,8 +529,10 @@ pub fn batched_gemm_fp8(
c.data_ptr() as *mut c_void,
a_scales.data_ptr() as *const c_void,
b_scales.data_ptr() as *const c_void,
total_rows, n as i32, m as i32,
std::ptr::null_mut(),
total_rows,
n as i32,
m as i32,
xserv_cuda::current_stream_raw(),
);
}
@@ -448,7 +550,13 @@ pub fn batched_gemm_fp8(
/// w_scales: [E, N, K/32] byte tensor — UE8M0 scale per 32-element block
///
/// Returns: [E, N] BF16, where y[e,n] = sum_k x[e,k] * dequant(W[e,n,k]).
pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: usize, k: usize) -> Tensor {
pub fn batched_gemv_mxfp4(
x: &Tensor,
w_packed: &Tensor,
w_scales: &Tensor,
n: usize,
k: usize,
) -> Tensor {
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous());
let e = x.shape()[0];
@@ -461,8 +569,10 @@ pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: u
w_packed.data_ptr() as *const c_void,
w_scales.data_ptr() as *const c_void,
y.data_ptr() as *mut c_void,
e as i32, n as i32, k as i32,
std::ptr::null_mut(),
e as i32,
n as i32,
k as i32,
xserv_cuda::current_stream_raw(),
);
}
y
@@ -470,15 +580,23 @@ pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: u
/// Dequantize MXFP4 weights [E, N, K] → BF16 [E, K, N] for the prefill GEMM path
/// (the BF16 batched GEMM expects weights as [E, K, N]).
pub fn dequant_mxfp4_to_bf16_t(w_packed: &Tensor, w_scales: &Tensor, e: usize, n: usize, k: usize) -> Tensor {
pub fn dequant_mxfp4_to_bf16_t(
w_packed: &Tensor,
w_scales: &Tensor,
e: usize,
n: usize,
k: usize,
) -> Tensor {
let out = Tensor::empty(&[e, k, n], DType::BF16, w_packed.device());
unsafe {
launch_dequant_mxfp4_to_bf16_t(
w_packed.data_ptr() as *const c_void,
w_scales.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
e as i32, n as i32, k as i32,
std::ptr::null_mut(),
e as i32,
n as i32,
k as i32,
xserv_cuda::current_stream_raw(),
);
}
out

View File

@@ -2,13 +2,35 @@ use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_rmsnorm_f32(x: *const c_void, gamma: *const c_void, out: *mut c_void,
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
fn launch_rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
fn launch_add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void,
normed_out: *mut c_void, sum_out: *mut c_void,
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
fn launch_rmsnorm_f32(
x: *const c_void,
gamma: *const c_void,
out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
);
fn launch_rmsnorm_bf16(
x: *const c_void,
gamma: *const c_void,
out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
);
fn launch_add_rmsnorm_bf16(
x: *const c_void,
residual: *const c_void,
gamma: *const c_void,
normed_out: *mut c_void,
sum_out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
);
}
pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
@@ -20,19 +42,35 @@ pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
assert_eq!(x.dtype(), gamma.dtype());
let rows = x.numel() / hidden_size;
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
assert!(
rows <= i32::MAX as usize,
"too many rows for i32 kernel param"
);
assert!(
hidden_size <= i32::MAX as usize,
"hidden_size too large for i32 kernel param"
);
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
unsafe {
match x.dtype() {
DType::F32 => launch_rmsnorm_f32(
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
x.data_ptr() as _,
gamma.data_ptr() as _,
out.data_ptr() as *mut c_void,
rows as i32,
hidden_size as i32,
eps,
xserv_cuda::current_stream_raw(),
),
DType::BF16 => launch_rmsnorm_bf16(
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
x.data_ptr() as _,
gamma.data_ptr() as _,
out.data_ptr() as *mut c_void,
rows as i32,
hidden_size as i32,
eps,
xserv_cuda::current_stream_raw(),
),
_ => panic!("unsupported dtype for rmsnorm"),
}
@@ -56,8 +94,14 @@ pub fn add_rmsnorm(x: &Tensor, residual: &Tensor, gamma: &Tensor, eps: f32) -> (
assert_eq!(gamma.shape(), &[hidden_size]);
let rows = x.numel() / hidden_size;
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
assert!(
rows <= i32::MAX as usize,
"too many rows for i32 kernel param"
);
assert!(
hidden_size <= i32::MAX as usize,
"hidden_size too large for i32 kernel param"
);
let normed_out = Tensor::empty(x.shape(), DType::BF16, x.device());
let sum_out = Tensor::empty(x.shape(), DType::BF16, x.device());
@@ -71,7 +115,7 @@ pub fn add_rmsnorm(x: &Tensor, residual: &Tensor, gamma: &Tensor, eps: f32) -> (
rows as i32,
hidden_size as i32,
eps,
std::ptr::null_mut(),
xserv_cuda::current_stream_raw(),
);
}

View File

@@ -3,15 +3,34 @@ use xserv_cuda::GpuBuffer;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_rope_f32(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
positions: *const c_void, num_tokens: i32, num_heads: i32,
head_dim: i32, stream: *mut c_void);
fn launch_rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
positions: *const c_void, num_tokens: i32, num_heads: i32,
head_dim: i32, stream: *mut c_void);
fn launch_compute_rope_cache(cos_cache: *mut c_void, sin_cache: *mut c_void,
max_seq_len: i32, half_dim: i32, theta: f32,
stream: *mut c_void);
fn launch_rope_f32(
x: *mut c_void,
cos_cache: *const c_void,
sin_cache: *const c_void,
positions: *const c_void,
num_tokens: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_rope_bf16(
x: *mut c_void,
cos_cache: *const c_void,
sin_cache: *const c_void,
positions: *const c_void,
num_tokens: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_compute_rope_cache(
cos_cache: *mut c_void,
sin_cache: *mut c_void,
max_seq_len: i32,
half_dim: i32,
theta: f32,
stream: *mut c_void,
);
}
pub struct RopeCache {
@@ -30,12 +49,21 @@ impl RopeCache {
unsafe {
launch_compute_rope_cache(
cos.as_mut_ptr() as _, sin.as_mut_ptr() as _,
max_seq_len as i32, half_dim as i32, theta, std::ptr::null_mut(),
cos.as_mut_ptr() as _,
sin.as_mut_ptr() as _,
max_seq_len as i32,
half_dim as i32,
theta,
xserv_cuda::current_stream_raw(),
);
}
Self { cos, sin, max_seq_len, half_dim }
Self {
cos,
sin,
max_seq_len,
half_dim,
}
}
/// YaRN (Yet another RoPE extensioN) RoPE cache. Applies frequency-dependent
@@ -68,8 +96,8 @@ impl RopeCache {
let mut inv_freq = vec![0.0f64; half_dim];
for i in 0..half_dim {
let pos_freq = theta.powf((2 * i) as f64 / dim);
let inv_freq_extrapolation = 1.0 / pos_freq; // original
let inv_freq_interpolation = 1.0 / (factor * pos_freq); // scaled
let inv_freq_extrapolation = 1.0 / pos_freq; // original
let inv_freq_interpolation = 1.0 / (factor * pos_freq); // scaled
// Linear ramp: 0 where we keep original, 1 where we interpolate
let ramp = if (high - low).abs() < 0.001 {
@@ -101,16 +129,19 @@ impl RopeCache {
let nbytes = total * std::mem::size_of::<f32>();
let mut cos = GpuBuffer::alloc(nbytes).expect("alloc yarn cos_cache");
let mut sin = GpuBuffer::alloc(nbytes).expect("alloc yarn sin_cache");
let cos_bytes = unsafe {
std::slice::from_raw_parts(cos_host.as_ptr() as *const u8, nbytes)
};
let sin_bytes = unsafe {
std::slice::from_raw_parts(sin_host.as_ptr() as *const u8, nbytes)
};
let cos_bytes =
unsafe { std::slice::from_raw_parts(cos_host.as_ptr() as *const u8, nbytes) };
let sin_bytes =
unsafe { std::slice::from_raw_parts(sin_host.as_ptr() as *const u8, nbytes) };
cos.copy_from_host(cos_bytes).unwrap();
sin.copy_from_host(sin_bytes).unwrap();
Self { cos, sin, max_seq_len, half_dim }
Self {
cos,
sin,
max_seq_len,
half_dim,
}
}
}
@@ -133,24 +164,46 @@ pub fn rope_inplace(x: &Tensor, cache: &RopeCache, positions: &[u32]) {
num_tokens * std::mem::size_of::<u32>(),
)
};
let mut pos_gpu = xserv_cuda::allocator::cached_alloc(pos_bytes.len()).expect("alloc positions");
let mut pos_gpu =
xserv_cuda::allocator::cached_alloc(pos_bytes.len()).expect("alloc positions");
pos_gpu.copy_from_host(pos_bytes).unwrap();
rope_inplace_device_pos(x, cache, pos_gpu.as_ptr() as *const c_void);
}
/// RoPE in-place with positions already on the GPU (u32, [num_tokens]).
/// Used by the CUDA-graph decode path, where the position lives in a
/// persistent device buffer updated outside the captured region.
pub fn rope_inplace_device_pos(x: &Tensor, cache: &RopeCache, pos_gpu: *const c_void) {
assert_eq!(x.ndim(), 3);
assert!(x.is_contiguous());
assert!(matches!(x.device(), Device::Cuda(_)));
let num_tokens = x.shape()[0];
let num_heads = x.shape()[1];
let head_dim = x.shape()[2];
assert_eq!(head_dim / 2, cache.half_dim);
unsafe {
match x.dtype() {
DType::F32 => launch_rope_f32(
x.data_ptr() as *mut c_void,
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
pos_gpu.as_ptr() as _,
num_tokens as i32, num_heads as i32, head_dim as i32,
std::ptr::null_mut(),
cache.cos.as_ptr() as _,
cache.sin.as_ptr() as _,
pos_gpu,
num_tokens as i32,
num_heads as i32,
head_dim as i32,
xserv_cuda::current_stream_raw(),
),
DType::BF16 => launch_rope_bf16(
x.data_ptr() as *mut c_void,
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
pos_gpu.as_ptr() as _,
num_tokens as i32, num_heads as i32, head_dim as i32,
std::ptr::null_mut(),
cache.cos.as_ptr() as _,
cache.sin.as_ptr() as _,
pos_gpu,
num_tokens as i32,
num_heads as i32,
head_dim as i32,
xserv_cuda::current_stream_raw(),
),
_ => panic!("unsupported dtype for rope"),
}

View File

@@ -2,8 +2,20 @@ use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_softmax_f32(x: *const c_void, out: *mut c_void, rows: i32, cols: i32, stream: *mut c_void);
fn launch_softmax_bf16(x: *const c_void, out: *mut c_void, rows: i32, cols: i32, stream: *mut c_void);
fn launch_softmax_f32(
x: *const c_void,
out: *mut c_void,
rows: i32,
cols: i32,
stream: *mut c_void,
);
fn launch_softmax_bf16(
x: *const c_void,
out: *mut c_void,
rows: i32,
cols: i32,
stream: *mut c_void,
);
}
/// Softmax along the last dimension.
@@ -14,19 +26,31 @@ pub fn softmax(x: &Tensor) -> Tensor {
let cols = *x.shape().last().unwrap();
let rows = x.numel() / cols;
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
assert!(cols <= i32::MAX as usize, "cols too large for i32 kernel param");
assert!(
rows <= i32::MAX as usize,
"too many rows for i32 kernel param"
);
assert!(
cols <= i32::MAX as usize,
"cols too large for i32 kernel param"
);
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
unsafe {
match x.dtype() {
DType::F32 => launch_softmax_f32(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, cols as i32, std::ptr::null_mut(),
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
rows as i32,
cols as i32,
xserv_cuda::current_stream_raw(),
),
DType::BF16 => launch_softmax_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, cols as i32, std::ptr::null_mut(),
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
rows as i32,
cols as i32,
xserv_cuda::current_stream_raw(),
),
_ => panic!("unsupported dtype for softmax"),
}

View File

@@ -2,19 +2,79 @@ use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_reshape_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
fn launch_merge_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
fn launch_transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
fn launch_transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
fn launch_repeat_kv_bf16(inp: *const c_void, out: *mut c_void, kv_heads: i32, n_rep: i32, seq_len: i32, head_dim: i32, stream: *mut c_void);
fn launch_strided_copy_bf16(inp: *const c_void, out: *mut c_void, numel: i32, ndim: i32,
shape0: i32, shape1: i32, shape2: i32, shape3: i32,
in_stride0: i32, in_stride1: i32, in_stride2: i32, in_stride3: i32,
in_offset: i32, stream: *mut c_void);
fn launch_strided_copy_f32(inp: *const c_void, out: *mut c_void, numel: i32, ndim: i32,
shape0: i32, shape1: i32, shape2: i32, shape3: i32,
in_stride0: i32, in_stride1: i32, in_stride2: i32, in_stride3: i32,
in_offset: i32, stream: *mut c_void);
fn launch_reshape_heads_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_merge_heads_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_transpose_hsd_to_shd_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_transpose_shd_to_hsd_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_repeat_kv_bf16(
inp: *const c_void,
out: *mut c_void,
kv_heads: i32,
n_rep: i32,
seq_len: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_strided_copy_bf16(
inp: *const c_void,
out: *mut c_void,
numel: i32,
ndim: i32,
shape0: i32,
shape1: i32,
shape2: i32,
shape3: i32,
in_stride0: i32,
in_stride1: i32,
in_stride2: i32,
in_stride3: i32,
in_offset: i32,
stream: *mut c_void,
);
fn launch_strided_copy_f32(
inp: *const c_void,
out: *mut c_void,
numel: i32,
ndim: i32,
shape0: i32,
shape1: i32,
shape2: i32,
shape3: i32,
in_stride0: i32,
in_stride1: i32,
in_stride2: i32,
in_stride3: i32,
in_offset: i32,
stream: *mut c_void,
);
}
/// [S, H*D] → [1, H, S, D] on GPU (BF16)
@@ -24,8 +84,12 @@ pub fn reshape_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim:
let out = Tensor::empty(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
unsafe {
launch_reshape_heads_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
seq_len as i32,
num_heads as i32,
head_dim as i32,
xserv_cuda::current_stream_raw(),
);
}
out
@@ -39,36 +103,58 @@ pub fn merge_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: u
let out = Tensor::empty(&[seq_len, hidden], DType::BF16, x.device());
unsafe {
launch_merge_heads_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
seq_len as i32,
num_heads as i32,
head_dim as i32,
xserv_cuda::current_stream_raw(),
);
}
out
}
/// [1, H, S, D] → [S, H, D] for RoPE on GPU (BF16)
pub fn transpose_for_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
pub fn transpose_for_rope_gpu(
x: &Tensor,
seq_len: usize,
num_heads: usize,
head_dim: usize,
) -> Tensor {
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
let out = Tensor::empty(&[seq_len, num_heads, head_dim], DType::BF16, x.device());
unsafe {
launch_transpose_hsd_to_shd_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
seq_len as i32,
num_heads as i32,
head_dim as i32,
xserv_cuda::current_stream_raw(),
);
}
out
}
/// [S, H, D] → [1, H, S, D] after RoPE on GPU (BF16)
pub fn transpose_from_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
pub fn transpose_from_rope_gpu(
x: &Tensor,
seq_len: usize,
num_heads: usize,
head_dim: usize,
) -> Tensor {
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
let out = Tensor::empty(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
unsafe {
launch_transpose_shd_to_hsd_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
seq_len as i32,
num_heads as i32,
head_dim as i32,
xserv_cuda::current_stream_raw(),
);
}
out
@@ -76,7 +162,9 @@ pub fn transpose_from_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, hea
/// [1, KV_H, S, D] → [1, KV_H*n_rep, S, D] on GPU (BF16)
pub fn repeat_kv_gpu(x: &Tensor, n_rep: usize) -> Tensor {
if n_rep == 1 { return x.clone(); }
if n_rep == 1 {
return x.clone();
}
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
let kv_heads = x.shape()[1];
@@ -86,8 +174,13 @@ pub fn repeat_kv_gpu(x: &Tensor, n_rep: usize) -> Tensor {
let out = Tensor::empty(&[1, new_heads, seq_len, head_dim], DType::BF16, x.device());
unsafe {
launch_repeat_kv_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
kv_heads as i32, n_rep as i32, seq_len as i32, head_dim as i32, std::ptr::null_mut(),
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
kv_heads as i32,
n_rep as i32,
seq_len as i32,
head_dim as i32,
xserv_cuda::current_stream_raw(),
);
}
out
@@ -122,20 +215,41 @@ pub fn strided_to_contiguous_gpu(x: &Tensor) -> Tensor {
unsafe {
match x.dtype() {
DType::BF16 => launch_strided_copy_bf16(
storage_ptr as _, out.data_ptr() as *mut c_void,
numel as i32, ndim as i32,
shape4[0], shape4[1], shape4[2], shape4[3],
strides4[0], strides4[1], strides4[2], strides4[3],
in_offset, std::ptr::null_mut(),
storage_ptr as _,
out.data_ptr() as *mut c_void,
numel as i32,
ndim as i32,
shape4[0],
shape4[1],
shape4[2],
shape4[3],
strides4[0],
strides4[1],
strides4[2],
strides4[3],
in_offset,
xserv_cuda::current_stream_raw(),
),
DType::F32 => launch_strided_copy_f32(
storage_ptr as _, out.data_ptr() as *mut c_void,
numel as i32, ndim as i32,
shape4[0], shape4[1], shape4[2], shape4[3],
strides4[0], strides4[1], strides4[2], strides4[3],
in_offset, std::ptr::null_mut(),
storage_ptr as _,
out.data_ptr() as *mut c_void,
numel as i32,
ndim as i32,
shape4[0],
shape4[1],
shape4[2],
shape4[3],
strides4[0],
strides4[1],
strides4[2],
strides4[3],
in_offset,
xserv_cuda::current_stream_raw(),
),
_ => panic!(
"strided_to_contiguous_gpu: unsupported dtype {:?}",
x.dtype()
),
_ => panic!("strided_to_contiguous_gpu: unsupported dtype {:?}", x.dtype()),
}
}
out

View File

@@ -1,11 +1,21 @@
use xserv_kernels::*;
use xserv_tensor::{Device, Tensor};
fn init() { xserv_cuda::device::set_device(0).unwrap(); }
fn init() {
xserv_cuda::device::set_device(0).unwrap();
}
fn cpu_attention(q: &[f32], k: &[f32], v: &[f32],
batch: usize, heads: usize, q_len: usize, kv_len: usize, head_dim: usize,
causal: bool) -> Vec<f32> {
fn cpu_attention(
q: &[f32],
k: &[f32],
v: &[f32],
batch: usize,
heads: usize,
q_len: usize,
kv_len: usize,
head_dim: usize,
causal: bool,
) -> Vec<f32> {
let mut out = vec![0.0f32; batch * heads * q_len * head_dim];
let scale = 1.0 / (head_dim as f32).sqrt();
@@ -70,8 +80,13 @@ fn check_close(a: &[f32], b: &[f32], atol: f32, name: &str) {
let mut max_err = 0.0f32;
for (i, (x, y)) in a.iter().zip(b).enumerate() {
let err = (x - y).abs();
if err > max_err { max_err = err; }
assert!(err <= atol, "{name}: mismatch at [{i}]: got {x}, expected {y}, err {err}");
if err > max_err {
max_err = err;
}
assert!(
err <= atol,
"{name}: mismatch at [{i}]: got {x}, expected {y}, err {err}"
);
}
println!("{name}: max_err = {max_err:.6e}");
}
@@ -105,7 +120,9 @@ fn test_batched_matmul() {
for i in 0..m {
for j in 0..n {
let mut s = 0.0f32;
for kk in 0..k { s += a_cpu[i * k + kk] * b_cpu[kk * n + j]; }
for kk in 0..k {
s += a_cpu[i * k + kk] * b_cpu[kk * n + j];
}
expected[i * n + j] = s;
}
}
@@ -116,7 +133,10 @@ fn test_batched_matmul() {
#[test]
fn test_attention_no_causal() {
init();
let b = 1; let h = 2; let s = 8; let d = 16;
let b = 1;
let h = 2;
let s = 8;
let d = 16;
let q_data = make_data(b * h * s * d);
let k_data = make_data(b * h * s * d);
let v_data = make_data(b * h * s * d);
@@ -126,13 +146,21 @@ fn test_attention_no_causal() {
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
let out = attention(&q, &k, &v, false).to_device(Device::Cpu);
check_close(out.as_slice::<f32>(), &expected, 1e-4, "attention_no_causal");
check_close(
out.as_slice::<f32>(),
&expected,
1e-4,
"attention_no_causal",
);
}
#[test]
fn test_attention_causal() {
init();
let b = 1; let h = 2; let s = 16; let d = 32;
let b = 1;
let h = 2;
let s = 16;
let d = 32;
let q_data = make_data(b * h * s * d);
let k_data = make_data(b * h * s * d);
let v_data = make_data(b * h * s * d);
@@ -148,7 +176,10 @@ fn test_attention_causal() {
#[test]
fn test_attention_causal_larger() {
init();
let b = 2; let h = 4; let s = 64; let d = 64;
let b = 2;
let h = 4;
let s = 64;
let d = 64;
let q_data = make_data(b * h * s * d);
let k_data = make_data(b * h * s * d);
let v_data = make_data(b * h * s * d);
@@ -158,18 +189,28 @@ fn test_attention_causal_larger() {
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
let out = attention(&q, &k, &v, true).to_device(Device::Cpu);
check_close(out.as_slice::<f32>(), &expected, 1e-2, "attention_causal_larger");
check_close(
out.as_slice::<f32>(),
&expected,
1e-2,
"attention_causal_larger",
);
}
#[test]
fn test_attention_causal_first_row_sees_only_first_token() {
init();
let b = 1; let h = 1; let s = 4; let d = 8;
let b = 1;
let h = 1;
let s = 4;
let d = 8;
let q_data = make_data(b * h * s * d);
let k_data = make_data(b * h * s * d);
let v_data: Vec<f32> = (0..s * d).map(|i| {
if i < d { 1.0 } else { 0.0 } // only first V row is nonzero
}).collect();
let v_data: Vec<f32> = (0..s * d)
.map(|i| {
if i < d { 1.0 } else { 0.0 } // only first V row is nonzero
})
.collect();
let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
@@ -181,7 +222,11 @@ fn test_attention_causal_first_row_sees_only_first_token() {
// output[0] should be exactly V[0] = [1, 1, 1, ...1]
let result = out.as_slice::<f32>();
for i in 0..d {
assert!((result[i] - 1.0).abs() < 1e-5,
"first row should equal V[0], got {} at dim {}", result[i], i);
assert!(
(result[i] - 1.0).abs() < 1e-5,
"first row should equal V[0], got {} at dim {}",
result[i],
i
);
}
}

View File

@@ -1,5 +1,5 @@
use half::bf16;
use xserv_kernels::{matmul, GemmBackend};
use xserv_kernels::{GemmBackend, matmul};
use xserv_tensor::{Device, Tensor};
fn cpu_matmul_f32(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
@@ -75,70 +75,110 @@ fn run_gemm_test_bf16(backend: GemmBackend, m: usize, n: usize, k: usize) {
// --- F32 tests ---
#[test]
fn test_gemm_naive_f32_small() { run_gemm_test_f32(GemmBackend::Naive, 4, 4, 4); }
fn test_gemm_naive_f32_small() {
run_gemm_test_f32(GemmBackend::Naive, 4, 4, 4);
}
#[test]
fn test_gemm_naive_f32_medium() { run_gemm_test_f32(GemmBackend::Naive, 64, 64, 64); }
fn test_gemm_naive_f32_medium() {
run_gemm_test_f32(GemmBackend::Naive, 64, 64, 64);
}
#[test]
fn test_gemm_naive_f32_rect() { run_gemm_test_f32(GemmBackend::Naive, 32, 64, 48); }
fn test_gemm_naive_f32_rect() {
run_gemm_test_f32(GemmBackend::Naive, 32, 64, 48);
}
#[test]
fn test_gemm_tiled_f32_small() { run_gemm_test_f32(GemmBackend::Tiled, 4, 4, 4); }
fn test_gemm_tiled_f32_small() {
run_gemm_test_f32(GemmBackend::Tiled, 4, 4, 4);
}
#[test]
fn test_gemm_tiled_f32_medium() { run_gemm_test_f32(GemmBackend::Tiled, 128, 128, 128); }
fn test_gemm_tiled_f32_medium() {
run_gemm_test_f32(GemmBackend::Tiled, 128, 128, 128);
}
#[test]
fn test_gemm_tiled_f32_rect() { run_gemm_test_f32(GemmBackend::Tiled, 65, 33, 97); }
fn test_gemm_tiled_f32_rect() {
run_gemm_test_f32(GemmBackend::Tiled, 65, 33, 97);
}
#[test]
fn test_gemm_cublas_f32_small() { run_gemm_test_f32(GemmBackend::CuBlas, 4, 4, 4); }
fn test_gemm_cublas_f32_small() {
run_gemm_test_f32(GemmBackend::CuBlas, 4, 4, 4);
}
#[test]
fn test_gemm_cublas_f32_medium() { run_gemm_test_f32(GemmBackend::CuBlas, 256, 256, 256); }
fn test_gemm_cublas_f32_medium() {
run_gemm_test_f32(GemmBackend::CuBlas, 256, 256, 256);
}
#[test]
fn test_gemm_cublas_f32_rect() { run_gemm_test_f32(GemmBackend::CuBlas, 65, 33, 97); }
fn test_gemm_cublas_f32_rect() {
run_gemm_test_f32(GemmBackend::CuBlas, 65, 33, 97);
}
// --- BF16 tests ---
#[test]
fn test_gemm_naive_bf16_small() { run_gemm_test_bf16(GemmBackend::Naive, 4, 4, 4); }
fn test_gemm_naive_bf16_small() {
run_gemm_test_bf16(GemmBackend::Naive, 4, 4, 4);
}
#[test]
fn test_gemm_naive_bf16_medium() { run_gemm_test_bf16(GemmBackend::Naive, 64, 64, 64); }
fn test_gemm_naive_bf16_medium() {
run_gemm_test_bf16(GemmBackend::Naive, 64, 64, 64);
}
#[test]
fn test_gemm_tiled_bf16_small() { run_gemm_test_bf16(GemmBackend::Tiled, 4, 4, 4); }
fn test_gemm_tiled_bf16_small() {
run_gemm_test_bf16(GemmBackend::Tiled, 4, 4, 4);
}
#[test]
fn test_gemm_tiled_bf16_medium() { run_gemm_test_bf16(GemmBackend::Tiled, 128, 128, 128); }
fn test_gemm_tiled_bf16_medium() {
run_gemm_test_bf16(GemmBackend::Tiled, 128, 128, 128);
}
#[test]
fn test_gemm_cublas_bf16_small() { run_gemm_test_bf16(GemmBackend::CuBlas, 4, 4, 4); }
fn test_gemm_cublas_bf16_small() {
run_gemm_test_bf16(GemmBackend::CuBlas, 4, 4, 4);
}
#[test]
fn test_gemm_cublas_bf16_medium() { run_gemm_test_bf16(GemmBackend::CuBlas, 256, 256, 256); }
fn test_gemm_cublas_bf16_medium() {
run_gemm_test_bf16(GemmBackend::CuBlas, 256, 256, 256);
}
// --- Custom GEMV tests (M=1, BF16 fast path) ---
#[test]
fn test_gemv_bf16_small() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 64, 64); }
fn test_gemv_bf16_small() {
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 64, 64);
}
#[test]
fn test_gemv_bf16_medium() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 256, 256); }
fn test_gemv_bf16_medium() {
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 256, 256);
}
#[test]
fn test_gemv_bf16_4096() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 4096, 4096); }
fn test_gemv_bf16_4096() {
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 4096, 4096);
}
#[test]
fn test_gemv_bf16_rect() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 512, 4096); }
fn test_gemv_bf16_rect() {
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 512, 4096);
}
// --- Larger benchmark-style tests ---
#[test]
fn test_gemm_cublas_f32_1024() { run_gemm_test_f32(GemmBackend::CuBlas, 1024, 1024, 1024); }
fn test_gemm_cublas_f32_1024() {
run_gemm_test_f32(GemmBackend::CuBlas, 1024, 1024, 1024);
}
#[test]
fn test_gemm_consistency_all_backends() {

View File

@@ -2,7 +2,9 @@ use half::bf16;
use xserv_kernels::*;
use xserv_tensor::{Device, Tensor};
fn init() { xserv_cuda::device::set_device(0).unwrap(); }
fn init() {
xserv_cuda::device::set_device(0).unwrap();
}
// --- CPU reference implementations ---
@@ -37,10 +39,12 @@ fn cpu_layernorm(x: &[f32], gamma: &[f32], beta: &[f32], eps: f32, hidden: usize
fn cpu_gelu(x: &[f32]) -> Vec<f32> {
let sqrt_2_over_pi = 0.7978845608f32;
x.iter().map(|&v| {
let inner = sqrt_2_over_pi * (v + 0.044715 * v * v * v);
0.5 * v * (1.0 + inner.tanh())
}).collect()
x.iter()
.map(|&v| {
let inner = sqrt_2_over_pi * (v + 0.044715 * v * v * v);
0.5 * v * (1.0 + inner.tanh())
})
.collect()
}
fn cpu_silu(x: &[f32]) -> Vec<f32> {
@@ -88,8 +92,13 @@ fn check_close(result: &[f32], expected: &[f32], atol: f32, name: &str) {
let mut max_err = 0.0f32;
for (i, (r, e)) in result.iter().zip(expected).enumerate() {
let err = (r - e).abs();
if err > max_err { max_err = err; }
assert!(err <= atol, "{name}: mismatch at [{i}]: got {r}, expected {e}, err {err}");
if err > max_err {
max_err = err;
}
assert!(
err <= atol,
"{name}: mismatch at [{i}]: got {r}, expected {e}, err {err}"
);
}
println!("{name}: max_err = {max_err:.6e}");
}
@@ -208,13 +217,18 @@ fn test_softmax_sum_to_one() {
init();
let rows = 4;
let cols = 2048;
let data: Vec<f32> = (0..rows * cols).map(|i| ((i % 31) as f32 - 15.0) * 0.5).collect();
let data: Vec<f32> = (0..rows * cols)
.map(|i| ((i % 31) as f32 - 15.0) * 0.5)
.collect();
let x = Tensor::from_slice(&data, &[rows, cols]).to_device(Device::Cuda(0));
let out = softmax(&x).to_device(Device::Cpu);
let result = out.as_slice::<f32>();
for r in 0..rows {
let row_sum: f32 = result[r * cols..(r + 1) * cols].iter().sum();
assert!((row_sum - 1.0).abs() < 1e-5, "softmax row {r} sum = {row_sum}");
assert!(
(row_sum - 1.0).abs() < 1e-5,
"softmax row {r} sum = {row_sum}"
);
}
}
@@ -247,8 +261,10 @@ fn test_embedding_f32() {
for i in 0..hidden {
let expected = table_data[tid as usize * hidden + i];
let got = result[seq_idx * hidden + i];
assert!((got - expected).abs() < 1e-6,
"embedding mismatch at [{seq_idx},{i}]: got {got}, expected {expected}");
assert!(
(got - expected).abs() < 1e-6,
"embedding mismatch at [{seq_idx},{i}]: got {got}, expected {expected}"
);
}
}
}
@@ -270,8 +286,8 @@ fn test_rope_f32() {
let mut expected = x_data.clone();
cpu_rope(&mut expected, &positions, num_heads, head_dim, theta);
let x = Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim])
.to_device(Device::Cuda(0));
let x =
Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim]).to_device(Device::Cuda(0));
let cache = RopeCache::new(64, head_dim, theta);
rope_inplace(&x, &cache, &positions);
@@ -292,8 +308,8 @@ fn test_rope_position_0_identity() {
.map(|i| (i as f32 + 1.0) * 0.1)
.collect();
let x = Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim])
.to_device(Device::Cuda(0));
let x =
Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim]).to_device(Device::Cuda(0));
let cache = RopeCache::new(64, head_dim, 10000.0);
rope_inplace(&x, &cache, &positions);

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use std::time::Instant;
use xserv_distributed::{TpContext, UniqueId, get_unique_id};
use xserv_model::{loader, GptOss, ModelConfig, PagedKVCache, BLOCK_SIZE};
use xserv_model::{BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, loader};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
@@ -23,8 +23,12 @@ fn main() {
eprintln!(
"gpt-oss-20b: layers={}, hidden={}, heads={}/{} kv, experts={}, top_k={}, vocab={}",
config.num_layers(), config.hidden(), config.num_heads(),
config.num_kv_heads(), config.num_experts(), config.experts_per_token(),
config.num_layers(),
config.hidden(),
config.num_heads(),
config.num_kv_heads(),
config.num_experts(),
config.experts_per_token(),
config.vocab_size
);
eprintln!("TP world={world}, max_tokens={max_tokens}");
@@ -59,17 +63,29 @@ fn main() {
let tp0 = Arc::new(TpContext::init(0, world, uid, 0));
eprintln!("[rank 0] Loading weights...");
let weights = loader::load_model_dir(&model_dir, Device::Cpu);
eprintln!("[rank 0] Loaded {} tensors, building model...", weights.len());
eprintln!(
"[rank 0] Loaded {} tensors, building model...",
weights.len()
);
let model = GptOss::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp0));
let total_blocks = max_blocks_per_seq + 64;
let mut cache = PagedKVCache::new_tp(
&config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, 0,
&config,
local_kv,
total_blocks,
0,
4,
max_blocks_per_seq,
DType::BF16,
0,
);
eprintln!("[rank 0] Ready.");
// Prompt
let prompt_arg = get_arg::<String>(&args, "--prompt");
let prompt = prompt_arg.as_deref().unwrap_or("What is the meaning of life?");
let prompt = prompt_arg
.as_deref()
.unwrap_or("What is the meaning of life?");
let token_ids = tokenizer.encode(prompt);
eprintln!("Prompt ({} tokens): {prompt}", token_ids.len());
@@ -83,11 +99,21 @@ fn main() {
// (oracle) next token. Removes free-running compounding so it isolates
// whether per-position logits agree with the llama.cpp trajectory.
if let Some(forced) = get_arg::<String>(&args, "--forced") {
let forced_ids: Vec<u32> = forced.split(',').filter_map(|s| s.trim().parse().ok()).collect();
let forced_ids: Vec<u32> = forced
.split(',')
.filter_map(|s| s.trim().parse().ok())
.collect();
let mut seq = token_ids.clone();
seq.extend_from_slice(&forced_ids);
// Workers must run the same prefill in lockstep (TP AllReduces match up).
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { tokens: seq.clone(), slot });
broadcast_cmd(
&worker_txs,
&worker_handles,
WorkerCmd::Prefill {
tokens: seq.clone(),
slot,
},
);
let logits = model.forward_prefill_paged(&seq, slot, &mut cache);
wait_workers(&worker_handles);
let logits_cpu = logits.to_device(Device::Cpu);
@@ -99,19 +125,31 @@ fn main() {
// position i predicts seq[i+1]; we check the forced region
for i in (plen - 1)..(seq.len() - 1) {
let row = &data[i * vocab..(i + 1) * vocab];
let argmax = row.iter().enumerate()
let argmax = row
.iter()
.enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(j, _)| j as u32).unwrap();
.map(|(j, _)| j as u32)
.unwrap();
let expected = seq[i + 1];
let ok = argmax == expected;
if ok { matches += 1; }
if ok {
matches += 1;
}
total += 1;
eprintln!("pos {i}: xserv_argmax={argmax} oracle={expected} {}", if ok {"OK"} else {"DIFF"});
eprintln!(
"pos {i}: xserv_argmax={argmax} oracle={expected} {}",
if ok { "OK" } else { "DIFF" }
);
}
eprintln!("\nTeacher-forced top-1 agreement: {matches}/{total} = {:.1}%",
100.0 * matches as f64 / total as f64);
eprintln!(
"\nTeacher-forced top-1 agreement: {matches}/{total} = {:.1}%",
100.0 * matches as f64 / total as f64
);
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown);
for (h, _) in worker_handles { h.join().unwrap(); }
for (h, _) in worker_handles {
h.join().unwrap();
}
return;
}
@@ -120,8 +158,18 @@ fn main() {
// per-position top-1 agreement bucketed by position. Localizes long-context
// decode degradation (which prefill teacher-forcing cannot see).
if let Some(forced) = get_arg::<String>(&args, "--forced-decode") {
let forced_ids: Vec<u32> = forced.split(',').filter_map(|s| s.trim().parse().ok()).collect();
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { tokens: token_ids.clone(), slot });
let forced_ids: Vec<u32> = forced
.split(',')
.filter_map(|s| s.trim().parse().ok())
.collect();
broadcast_cmd(
&worker_txs,
&worker_handles,
WorkerCmd::Prefill {
tokens: token_ids.clone(),
slot,
},
);
let logits = model.forward_prefill_paged(&token_ids, slot, &mut cache);
wait_workers(&worker_handles);
let mut pred = sample_greedy_last(&logits); // prediction for forced[0]
@@ -133,34 +181,55 @@ fn main() {
matches += ok as usize;
total += 1;
let b = i / bucket;
if buckets.len() <= b { buckets.push((0, 0)); }
if buckets.len() <= b {
buckets.push((0, 0));
}
buckets[b].0 += ok as usize;
buckets[b].1 += 1;
// Teacher-force: feed the oracle token through the decode path.
let pos = cache.seq_len(slot);
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Decode {
tokens: vec![f], positions: vec![pos], slots: vec![slot],
});
broadcast_cmd(
&worker_txs,
&worker_handles,
WorkerCmd::Decode {
tokens: vec![f],
positions: vec![pos],
slots: vec![slot],
},
);
let logits = model.forward_decode_paged(&[f], &[pos], &[slot], &mut cache);
wait_workers(&worker_handles);
pred = sample_greedy_last(&logits);
}
eprintln!("Teacher-forced DECODE agreement: {matches}/{total} = {:.1}%",
100.0 * matches as f64 / total as f64);
eprintln!(
"Teacher-forced DECODE agreement: {matches}/{total} = {:.1}%",
100.0 * matches as f64 / total as f64
);
for (b, (m, t)) in buckets.iter().enumerate() {
eprintln!(" pos[{:>4}..{:<4}]: {m:>3}/{t:<3} = {:.0}%",
b * bucket, b * bucket + t, 100.0 * (*m as f64) / (*t as f64));
eprintln!(
" pos[{:>4}..{:<4}]: {m:>3}/{t:<3} = {:.0}%",
b * bucket,
b * bucket + t,
100.0 * (*m as f64) / (*t as f64)
);
}
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown);
for (h, _) in worker_handles { h.join().unwrap(); }
for (h, _) in worker_handles {
h.join().unwrap();
}
return;
}
// Prefill
let t0 = Instant::now();
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill {
tokens: token_ids.clone(), slot,
});
broadcast_cmd(
&worker_txs,
&worker_handles,
WorkerCmd::Prefill {
tokens: token_ids.clone(),
slot,
},
);
let logits = model.forward_prefill_paged(&token_ids, slot, &mut cache);
wait_workers(&worker_handles);
let ttft = t0.elapsed();
@@ -172,18 +241,27 @@ fn main() {
print!("{prompt}");
// Decode
let mut decoder = GraphedGptOssDecoder::new();
let decode_start = Instant::now();
for _ in 1..max_tokens {
let text = tokenizer.decode(&[next]);
print!("{text}");
if tokenizer.eos_token_id() == Some(next) { break; }
if tokenizer.eos_token_id() == Some(next) {
break;
}
let pos = cache.seq_len(slot);
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Decode {
tokens: vec![next], positions: vec![pos], slots: vec![slot],
});
let logits = model.forward_decode_paged(&[next], &[pos], &[slot], &mut cache);
broadcast_cmd(
&worker_txs,
&worker_handles,
WorkerCmd::Decode {
tokens: vec![next],
positions: vec![pos],
slots: vec![slot],
},
);
let logits = decoder.decode(&model, &[next], &[pos], &[slot], &mut cache);
wait_workers(&worker_handles);
next = sample_greedy_last(&logits);
@@ -195,13 +273,20 @@ fn main() {
let gen_tokens = output_tokens.len();
let full_text = tokenizer.decode(&output_tokens);
eprintln!("\nGenerated text: {full_text}");
eprintln!("Token IDs: {:?}", &output_tokens[..output_tokens.len().min(20)]);
eprintln!(
"Token IDs: {:?}",
&output_tokens[..output_tokens.len().min(20)]
);
let tpot = if gen_tokens > 1 {
decode_elapsed.as_secs_f64() * 1000.0 / (gen_tokens - 1) as f64
} else { 0.0 };
} else {
0.0
};
let tok_s = if gen_tokens > 1 {
(gen_tokens - 1) as f64 / decode_elapsed.as_secs_f64()
} else { 0.0 };
} else {
0.0
};
eprintln!("\n--- Performance ---");
eprintln!("Generated: {} tokens", gen_tokens);
@@ -221,8 +306,15 @@ fn main() {
#[derive(Clone)]
enum WorkerCmd {
Register(usize),
Prefill { tokens: Vec<u32>, slot: usize },
Decode { tokens: Vec<u32>, positions: Vec<usize>, slots: Vec<usize> },
Prefill {
tokens: Vec<u32>,
slot: usize,
},
Decode {
tokens: Vec<u32>,
positions: Vec<usize>,
slots: Vec<usize>,
},
Shutdown,
}
@@ -240,16 +332,25 @@ fn worker_loop(
let tp = Arc::new(TpContext::init(rank, world, uid, rank as u32));
eprintln!("[rank {rank}] Loading weights...");
let weights = loader::load_model_dir(&model_dir, Device::Cpu);
let model = GptOss::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp));
let model =
GptOss::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp));
let local_kv = config.num_kv_heads() / world;
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
let total_blocks = max_blocks_per_seq + 64;
let mut cache = PagedKVCache::new_tp(
&config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, rank as u32,
&config,
local_kv,
total_blocks,
0,
4,
max_blocks_per_seq,
DType::BF16,
rank as u32,
);
eprintln!("[rank {rank}] Ready.");
ack_tx.send(()).unwrap();
let mut decoder = GraphedGptOssDecoder::new();
while let Ok(cmd) = rx.recv() {
match cmd {
WorkerCmd::Register(slot) => {
@@ -258,8 +359,12 @@ fn worker_loop(
WorkerCmd::Prefill { tokens, slot } => {
let _ = model.forward_prefill_paged(&tokens, slot, &mut cache);
}
WorkerCmd::Decode { tokens, positions, slots } => {
let _ = model.forward_decode_paged(&tokens, &positions, &slots, &mut cache);
WorkerCmd::Decode {
tokens,
positions,
slots,
} => {
let _ = decoder.decode(&model, &tokens, &positions, &slots, &mut cache);
}
WorkerCmd::Shutdown => break,
}
@@ -286,20 +391,26 @@ fn wait_workers(handles: &[(std::thread::JoinHandle<()>, std::sync::mpsc::Receiv
fn sample_greedy_last(logits: &xserv_tensor::Tensor) -> u32 {
use half::bf16;
assert_eq!(logits.ndim(), 2);
// GPU argmax fast path (4-byte D2H instead of the full logits row).
if logits.dtype() == xserv_tensor::DType::BF16 && logits.is_contiguous() {
let ids = xserv_kernels::argmax_bf16_to_host(logits);
return *ids.last().unwrap();
}
let logits_cpu = logits.to_device(Device::Cpu);
let vocab_size = logits.shape()[1];
let seq_len = logits.shape()[0];
let data = logits_cpu.as_slice::<bf16>();
let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
last.iter().enumerate()
last.iter()
.enumerate()
.max_by(|a, b| {
let af = a.1.to_f32();
let bf = b.1.to_f32();
af.partial_cmp(&bf).unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(i, _)| i as u32).unwrap()
.map(|(i, _)| i as u32)
.unwrap()
}
fn get_arg<T: std::str::FromStr>(args: &[String], flag: &str) -> Option<T> {

View File

@@ -1,7 +1,7 @@
use std::path::PathBuf;
use std::time::Instant;
use xserv_model::gpt2::{sample_greedy, KVCache};
use xserv_model::{loader, GPT2, ModelConfig};
use xserv_model::gpt2::{KVCache, sample_greedy};
use xserv_model::{GPT2, ModelConfig, loader};
use xserv_tensor::Device;
use xserv_tokenizer::Tokenizer;
@@ -104,9 +104,15 @@ fn main() {
let tbt_us = if !token_times_us.is_empty() {
token_times_us.iter().sum::<u128>() / token_times_us.len() as u128
} else { 0 };
} else {
0
};
let total_gen_us: u128 = ttft_us + token_times_us.iter().sum::<u128>();
let tpot_us = if num_generated > 0 { total_gen_us / num_generated as u128 } else { 0 };
let tpot_us = if num_generated > 0 {
total_gen_us / num_generated as u128
} else {
0
};
let gen_text_escaped = generated_text
.replace('\\', "\\\\")
@@ -124,11 +130,16 @@ fn main() {
print!("\"ttft_us\": {ttft_us}, ");
print!("\"tbt_us\": {tbt_us}, ");
print!("\"tpot_us\": {tpot_us}}}");
if i < prompts.len() - 1 { println!(","); } else { println!(); }
if i < prompts.len() - 1 {
println!(",");
} else {
println!();
}
eprintln!(
"[{}/{}] input={input_len}tok gen={num_generated}tok ttft={:.1}ms tbt={:.1}ms | {}",
i + 1, prompts.len(),
i + 1,
prompts.len(),
ttft_us as f64 / 1000.0,
tbt_us as f64 / 1000.0,
&generated_text.replace('\n', " ")[..generated_text.len().min(60)]
@@ -138,12 +149,18 @@ fn main() {
}
fn generate_with_cache(
model: &GPT2, config: &ModelConfig, tokenizer: &Tokenizer,
input_ids: &[u32], gen_tokens: usize,
model: &GPT2,
config: &ModelConfig,
tokenizer: &Tokenizer,
input_ids: &[u32],
gen_tokens: usize,
) -> (Vec<u32>, u128, Vec<u128>) {
let mut cache = KVCache::new(
config.num_layers(), config.num_heads(), config.head_dim(),
xserv_tensor::DType::F32, Device::Cuda(0),
config.num_layers(),
config.num_heads(),
config.head_dim(),
xserv_tensor::DType::F32,
Device::Cuda(0),
);
// Prefill
@@ -163,15 +180,19 @@ fn generate_with_cache(
let next = sample_greedy(&logits);
token_times.push(t_start.elapsed().as_micros());
generated.push(next);
if tokenizer.eos_token_id() == Some(next) { break; }
if tokenizer.eos_token_id() == Some(next) {
break;
}
}
(generated, ttft_us, token_times)
}
fn generate_no_cache(
model: &GPT2, tokenizer: &Tokenizer,
input_ids: &[u32], gen_tokens: usize,
model: &GPT2,
tokenizer: &Tokenizer,
input_ids: &[u32],
gen_tokens: usize,
) -> (Vec<u32>, u128, Vec<u128>) {
let mut all_ids = input_ids.to_vec();
@@ -191,7 +212,9 @@ fn generate_no_cache(
token_times.push(t_start.elapsed().as_micros());
all_ids.push(next);
generated.push(next);
if tokenizer.eos_token_id() == Some(next) { break; }
if tokenizer.eos_token_id() == Some(next) {
break;
}
}
(generated, ttft_us, token_times)

View File

@@ -1,7 +1,7 @@
use std::path::PathBuf;
use std::time::Instant;
use xserv_model::qwen3::sample_greedy;
use xserv_model::{loader, DecodeGraphState, GpuKVCache, ModelConfig, Qwen3};
use xserv_model::{DecodeGraphState, GpuKVCache, ModelConfig, Qwen3, loader};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
@@ -139,18 +139,35 @@ fn main() {
} else {
// Replay captured graphs
let pos = cache.seq_len() as u32;
graph.execute(last, pos, &mut cache, &layer_ptrs, embed, config.vocab_size as i32, config.hidden() as i32);
graph.execute(
last,
pos,
&mut cache,
&layer_ptrs,
embed,
config.vocab_size as i32,
config.hidden() as i32,
);
cache.advance_seq_len(1);
// Read logits from graph buffer
let vocab_size = config.vocab_size;
let mut logits_bytes = vec![0u8; vocab_size * 2];
graph.logits_buffer().copy_to_host(&mut logits_bytes).unwrap();
graph
.logits_buffer()
.copy_to_host(&mut logits_bytes)
.unwrap();
let logits_data: &[half::bf16] = unsafe {
std::slice::from_raw_parts(logits_bytes.as_ptr() as *const half::bf16, vocab_size)
std::slice::from_raw_parts(
logits_bytes.as_ptr() as *const half::bf16,
vocab_size,
)
};
logits_data.iter().enumerate()
logits_data
.iter()
.enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(idx, _)| idx as u32).unwrap()
.map(|(idx, _)| idx as u32)
.unwrap()
}
} else {
let logits = model.forward_gpu_cache(&[last], &mut cache);
@@ -159,16 +176,24 @@ fn main() {
token_times.push(t_start.elapsed().as_micros());
generated.push(next);
if tokenizer.eos_token_id() == Some(next) { break; }
if tokenizer.eos_token_id() == Some(next) {
break;
}
}
let num_generated = generated.len();
let generated_text = tokenizer.decode(&generated);
let tbt_us = if !token_times.is_empty() {
token_times.iter().sum::<u128>() / token_times.len() as u128
} else { 0 };
} else {
0
};
let total_gen_us: u128 = ttft_us + token_times.iter().sum::<u128>();
let tpot_us = if num_generated > 0 { total_gen_us / num_generated as u128 } else { 0 };
let tpot_us = if num_generated > 0 {
total_gen_us / num_generated as u128
} else {
0
};
let gen_text_escaped = generated_text
.replace('\\', "\\\\")
@@ -186,13 +211,18 @@ fn main() {
print!("\"ttft_us\": {ttft_us}, ");
print!("\"tbt_us\": {tbt_us}, ");
print!("\"tpot_us\": {tpot_us}}}");
if i < prompts.len() - 1 { println!(","); } else { println!(); }
if i < prompts.len() - 1 {
println!(",");
} else {
println!();
}
let display_text = generated_text.replace('\n', " ");
let truncated: String = display_text.chars().take(60).collect();
eprintln!(
"[{}/{}] input={input_len}tok gen={num_generated}tok ttft={:.1}ms tbt={:.1}ms | {}",
i + 1, prompts.len(),
i + 1,
prompts.len(),
ttft_us as f64 / 1000.0,
tbt_us as f64 / 1000.0,
truncated

View File

@@ -18,7 +18,7 @@ use std::thread;
use std::time::Instant;
use xserv_model::qwen3::sample_greedy;
use xserv_model::{loader, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
@@ -35,8 +35,13 @@ fn main() {
std::process::exit(1);
}
let model_dir = PathBuf::from(&args[1]);
let world: usize = arg(&args, "--tp").and_then(|s| s.parse().ok()).unwrap_or(1).max(1);
let gen_tokens: usize = arg(&args, "--gen-tokens").and_then(|s| s.parse().ok()).unwrap_or(64);
let world: usize = arg(&args, "--tp")
.and_then(|s| s.parse().ok())
.unwrap_or(1)
.max(1);
let gen_tokens: usize = arg(&args, "--gen-tokens")
.and_then(|s| s.parse().ok())
.unwrap_or(64);
let devices: Vec<u32> = match arg(&args, "--devices") {
Some(s) => s.split(',').filter_map(|d| d.trim().parse().ok()).collect(),
None => (0..world as u32).collect(),
@@ -67,7 +72,11 @@ fn main() {
// Tensors are not Send (their Storage holds a raw GPU pointer), so each rank
// thread loads its own CPU copy of the weights and shards in-thread. Loading
// is not part of the timed region.
let id = if world > 1 { Some(xserv_distributed::get_unique_id()) } else { None };
let id = if world > 1 {
Some(xserv_distributed::get_unique_id())
} else {
None
};
let handles: Vec<_> = (0..world)
.map(|rank| {
@@ -76,7 +85,9 @@ fn main() {
let prompt_ids = prompt_ids.clone();
let device = devices[rank];
thread::spawn(move || {
run_rank(rank, world, device, id, config, model_dir, prompt_ids, gen_tokens, eos)
run_rank(
rank, world, device, id, config, model_dir, prompt_ids, gen_tokens, eos,
)
})
})
.collect();
@@ -91,7 +102,10 @@ fn main() {
let results = rank0.expect("rank 0 produced no results");
println!("\n=== TP={world} (devices {devices:?}) — Qwen3 E2E benchmark ===");
println!("{:<45} {:>10} {:>12} {:>8}", "prompt", "TTFT(ms)", "decode tok/s", "gen");
println!(
"{:<45} {:>10} {:>12} {:>8}",
"prompt", "TTFT(ms)", "decode tok/s", "gen"
);
let mut tps_sum = 0.0;
for (i, r) in results.iter().enumerate() {
let text = tokenizer.decode(&r.gen_ids).replace('\n', " ");
@@ -99,16 +113,29 @@ fn main() {
let p: String = prompts[i].chars().take(43).collect();
println!(
"{:<45} {:>10.1} {:>12.1} {:>8} | {}",
p, r.ttft_ms, r.decode_tok_s, r.gen_ids.len(), short
p,
r.ttft_ms,
r.decode_tok_s,
r.gen_ids.len(),
short
);
tps_sum += r.decode_tok_s;
}
println!("--- mean decode throughput: {:.1} tok/s ---", tps_sum / results.len() as f64);
println!(
"--- mean decode throughput: {:.1} tok/s ---",
tps_sum / results.len() as f64
);
// Machine-readable line for cross-TP correctness diffing (rank 0 token ids).
let all_ids: Vec<String> = results
.iter()
.map(|r| r.gen_ids.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(","))
.map(|r| {
r.gen_ids
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(",")
})
.collect();
println!("CORRECTNESS_IDS tp={world} {}", all_ids.join(" | "));
}
@@ -126,7 +153,12 @@ fn run_rank(
) -> Option<Vec<PromptResult>> {
// Bind this thread to its GPU and set up the TP communicator.
let tp = if world > 1 {
Some(Arc::new(xserv_distributed::TpContext::init(rank, world, id.unwrap(), device)))
Some(Arc::new(xserv_distributed::TpContext::init(
rank,
world,
id.unwrap(),
device,
)))
} else {
xserv_cuda::device::set_device(device).unwrap();
None
@@ -142,7 +174,14 @@ fn run_rank(
let max_blocks_per_seq = max_seq.div_ceil(BLOCK_SIZE);
let total_blocks = max_blocks_per_seq + 8;
let mut cache = PagedKVCache::new_tp(
&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, device,
&config,
local_kv,
total_blocks,
0,
1,
max_blocks_per_seq,
DType::BF16,
device,
);
// Warmup (init kernels / allocator / NCCL channels) — not timed.
@@ -177,12 +216,20 @@ fn run_rank(
steps += 1;
}
let decode_s = t1.elapsed().as_secs_f64();
let decode_tok_s = if steps > 0 && decode_s > 0.0 { steps as f64 / decode_s } else { 0.0 };
let decode_tok_s = if steps > 0 && decode_s > 0.0 {
steps as f64 / decode_s
} else {
0.0
};
cache.free_sequence(0);
if rank == 0 {
out.push(PromptResult { gen_ids: generated, ttft_ms, decode_tok_s });
out.push(PromptResult {
gen_ids: generated,
ttft_ms,
decode_tok_s,
});
}
}
@@ -190,5 +237,8 @@ fn run_rank(
}
fn arg<'a>(args: &'a [String], flag: &str) -> Option<&'a str> {
args.iter().position(|a| a == flag).and_then(|i| args.get(i + 1)).map(|s| s.as_str())
args.iter()
.position(|a| a == flag)
.and_then(|i| args.get(i + 1))
.map(|s| s.as_str())
}

View File

@@ -1,8 +1,8 @@
use half::bf16;
use std::path::PathBuf;
use xserv_model::{loader, KVCache, ModelConfig, Qwen3};
use xserv_model::{KVCache, ModelConfig, Qwen3, loader};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
use half::bf16;
fn main() {
let args: Vec<String> = std::env::args().collect();
@@ -20,8 +20,11 @@ fn main() {
eprintln!("Token IDs: {token_ids:?}");
let mut cache = KVCache::new(
config.num_layers(), config.num_kv_heads(), config.head_dim(),
DType::BF16, Device::Cuda(0),
config.num_layers(),
config.num_kv_heads(),
config.head_dim(),
DType::BF16,
Device::Cuda(0),
);
let logits = model.forward_with_cache(&token_ids, &mut cache);
let logits_cpu = logits.to_device(Device::Cpu);
@@ -31,7 +34,9 @@ fn main() {
// Print top-20 logits for the last position
let last_row = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
let mut indexed: Vec<(usize, f32)> = last_row.iter().enumerate()
let mut indexed: Vec<(usize, f32)> = last_row
.iter()
.enumerate()
.map(|(i, v)| (i, v.to_f32()))
.collect();
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

View File

@@ -1,10 +1,13 @@
use std::io::{self, IsTerminal, Read, Write};
use std::path::PathBuf;
use std::sync::{mpsc, Arc};
use std::sync::{Arc, mpsc};
use std::thread;
use xserv_model::{loader, sample, sample_greedy_penalized, GptOss, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE};
use xserv_model::{
BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, SamplingParams,
loader, sample, sample_greedy_penalized,
};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
@@ -14,13 +17,24 @@ enum ChatModel {
}
impl ChatModel {
fn forward_prefill_paged(&self, tokens: &[u32], slot: usize, cache: &mut PagedKVCache) -> xserv_tensor::Tensor {
fn forward_prefill_paged(
&self,
tokens: &[u32],
slot: usize,
cache: &mut PagedKVCache,
) -> xserv_tensor::Tensor {
match self {
ChatModel::Qwen3(m) => m.forward_prefill_paged(tokens, slot, cache),
ChatModel::GptOss(m) => m.forward_prefill_paged(tokens, slot, cache),
}
}
fn forward_decode_paged(&self, tokens: &[u32], positions: &[usize], slots: &[usize], cache: &mut PagedKVCache) -> xserv_tensor::Tensor {
fn forward_decode_paged(
&self,
tokens: &[u32],
positions: &[usize],
slots: &[usize],
cache: &mut PagedKVCache,
) -> xserv_tensor::Tensor {
match self {
ChatModel::Qwen3(m) => m.forward_decode_paged(tokens, positions, slots, cache),
ChatModel::GptOss(m) => m.forward_decode_paged(tokens, positions, slots, cache),
@@ -33,8 +47,15 @@ impl ChatModel {
enum TpCommand {
Register(usize),
Free(usize),
Prefill { tokens: Vec<u32>, slot: usize },
Decode { tokens: Vec<u32>, positions: Vec<usize>, slots: Vec<usize> },
Prefill {
tokens: Vec<u32>,
slot: usize,
},
Decode {
tokens: Vec<u32>,
positions: Vec<usize>,
slots: Vec<usize>,
},
}
struct TpHandle {
@@ -56,7 +77,8 @@ impl TpHandle {
}
fn tp_worker_loop(
rank: usize, world: usize,
rank: usize,
world: usize,
id: xserv_distributed::UniqueId,
model_dir: std::path::PathBuf,
config: ModelConfig,
@@ -64,28 +86,68 @@ fn tp_worker_loop(
cmd_rx: mpsc::Receiver<TpCommand>,
ack_tx: mpsc::Sender<()>,
) {
let tp = Arc::new(xserv_distributed::TpContext::init(rank, world, id, rank as u32));
let tp = Arc::new(xserv_distributed::TpContext::init(
rank,
world,
id,
rank as u32,
));
let weights = loader::load_model_dir(&model_dir, Device::Cpu);
let model = if config.is_moe() {
ChatModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp)))
ChatModel::GptOss(GptOss::from_weights_tp(
config.clone(),
weights,
rank,
world,
rank as u32,
Some(tp),
))
} else {
ChatModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp)))
ChatModel::Qwen3(Qwen3::from_weights_tp(
config.clone(),
weights,
rank,
world,
rank as u32,
Some(tp),
))
};
let local_kv = config.num_kv_heads() / world;
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
let total_blocks = max_blocks_per_seq + 8;
let mut cache = PagedKVCache::new_tp(
&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, rank as u32,
&config,
local_kv,
total_blocks,
0,
1,
max_blocks_per_seq,
DType::BF16,
rank as u32,
);
let mut decoder = GraphedGptOssDecoder::new();
while let Ok(cmd) = cmd_rx.recv() {
match cmd {
TpCommand::Register(slot) => { let _ = cache.register_sequence(slot); }
TpCommand::Register(slot) => {
let _ = cache.register_sequence(slot);
}
TpCommand::Free(slot) => cache.free_sequence(slot),
TpCommand::Prefill { tokens, slot } => {
let _ = model.forward_prefill_paged(&tokens, slot, &mut cache);
}
TpCommand::Decode { tokens, positions, slots } => {
let _ = model.forward_decode_paged(&tokens, &positions, &slots, &mut cache);
TpCommand::Decode {
tokens,
positions,
slots,
} => {
let _ = chat_decode(
&model,
&mut decoder,
&tokens,
&positions,
&slots,
&mut cache,
);
}
}
let _ = ack_tx.send(());
@@ -220,7 +282,13 @@ fn read_line_edited(prompt: &str) -> Line {
}
b => {
// UTF-8 multi-byte: read the continuation bytes for this char.
let extra = if b >= 0xF0 { 3 } else if b >= 0xE0 { 2 } else { 1 };
let extra = if b >= 0xF0 {
3
} else if b >= 0xE0 {
2
} else {
1
};
let mut bytes = vec![b];
let mut cont = [0u8; 1];
let mut ok = true;
@@ -274,7 +342,8 @@ fn main() {
if world > 1 {
assert!(
config.num_kv_heads() % world == 0,
"num_kv_heads {} not divisible by tp {world}", config.num_kv_heads()
"num_kv_heads {} not divisible by tp {world}",
config.num_kv_heads()
);
}
@@ -289,7 +358,16 @@ fn main() {
let model_dir = opts.model_dir.clone();
let config = config.clone();
thread::spawn(move || {
tp_worker_loop(rank, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx);
tp_worker_loop(
rank,
world,
id,
model_dir,
config,
max_seq_len,
ctx_rx,
ack_tx,
);
});
}
eprintln!("Loading weights (tp={world})...");
@@ -297,14 +375,37 @@ fn main() {
let weights = loader::load_model_dir(&opts.model_dir, Device::Cpu);
eprintln!("Loaded {} tensors", weights.len());
let m = if is_moe {
ChatModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp)))
ChatModel::GptOss(GptOss::from_weights_tp(
config.clone(),
weights,
0,
world,
0,
Some(tp),
))
} else {
ChatModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp)))
ChatModel::Qwen3(Qwen3::from_weights_tp(
config.clone(),
weights,
0,
world,
0,
Some(tp),
))
};
let local_kv = config.num_kv_heads() / world;
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
let total_blocks = max_blocks_per_seq + 8;
let c = PagedKVCache::new_tp(&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, 0);
let c = PagedKVCache::new_tp(
&config,
local_kv,
total_blocks,
0,
1,
max_blocks_per_seq,
DType::BF16,
0,
);
let h = TpHandle { cmd_txs, ack_rx };
(m, c, Some(h))
} else {
@@ -321,7 +422,11 @@ fn main() {
};
let tokenizer = Tokenizer::from_file(&opts.model_dir.join("tokenizer.json"));
if let Some(h) = &tp_handle { h.send(TpCommand::Register(SLOT)); h.wait(); }
let mut decoder = GraphedGptOssDecoder::new();
if let Some(h) = &tp_handle {
h.send(TpCommand::Register(SLOT));
h.wait();
}
cache.register_sequence(SLOT).expect("register chat slot");
let use_color = opts.color && io::stdout().is_terminal();
@@ -363,11 +468,8 @@ fn main() {
if is_moe {
// Harmony multi-turn: re-render the whole conversation (prior
// analysis dropped) and re-prefill into a freshly cleared slot.
let prompt = build_conversation_gpt_oss(
opts.system_prompt.as_deref(),
&moe_history,
input,
);
let prompt =
build_conversation_gpt_oss(opts.system_prompt.as_deref(), &moe_history, input);
let prompt_tokens = tokenizer.encode(&prompt);
if prompt_tokens.is_empty() {
continue;
@@ -384,8 +486,17 @@ fn main() {
print!("assistant> ");
io::stdout().flush().unwrap();
let (_finish, answer) = generate_with_paged_cache(
&model, &mut cache, &tokenizer, &prompt_tokens, &opts.sampling,
max_new_tokens, use_color, &tp_handle, is_moe, opts.enable_thinking,
&model,
&mut decoder,
&mut cache,
&tokenizer,
&prompt_tokens,
&opts.sampling,
max_new_tokens,
use_color,
&tp_handle,
is_moe,
opts.enable_thinking,
);
moe_history.push((input.to_string(), answer));
println!();
@@ -421,6 +532,7 @@ fn main() {
io::stdout().flush().unwrap();
let (finish, _answer) = generate_with_paged_cache(
&model,
&mut decoder,
&mut cache,
&tokenizer,
&prompt_tokens,
@@ -433,10 +545,24 @@ fn main() {
);
match finish {
Finish::Stop { token_id } => {
append_after_stop(&model, &mut cache, &tokenizer, max_seq_len, token_id, &tp_handle);
append_after_stop(
&model,
&mut cache,
&tokenizer,
max_seq_len,
token_id,
&tp_handle,
);
}
Finish::Length => {
append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, "<|im_end|>\n", &tp_handle);
append_text_to_cache(
&model,
&mut cache,
&tokenizer,
max_seq_len,
"<|im_end|>\n",
&tp_handle,
);
}
}
println!();
@@ -445,9 +571,15 @@ fn main() {
/// Free and re-register the single chat KV slot (clears all cached context).
fn reset_slot(cache: &mut PagedKVCache, tp: &Option<TpHandle>) {
if let Some(h) = tp { h.send(TpCommand::Free(SLOT)); h.wait(); }
if let Some(h) = tp {
h.send(TpCommand::Free(SLOT));
h.wait();
}
cache.free_sequence(SLOT);
if let Some(h) = tp { h.send(TpCommand::Register(SLOT)); h.wait(); }
if let Some(h) = tp {
h.send(TpCommand::Register(SLOT));
h.wait();
}
cache.register_sequence(SLOT).expect("register chat slot");
}
@@ -585,7 +717,15 @@ fn new_paged_cache(config: &ModelConfig, max_seq_len: usize) -> PagedKVCache {
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
let total_blocks = (max_blocks_per_seq + 1).max(2);
// Single-slot interactive CLI: no swap pool (cpu_total_blocks = 0).
PagedKVCache::new(config, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, 0)
PagedKVCache::new(
config,
total_blocks,
0,
1,
max_blocks_per_seq,
DType::BF16,
0,
)
}
fn build_turn_prompt(
@@ -665,7 +805,10 @@ fn build_conversation_gpt_oss(
/// civil-calendar conversion (same algorithm the server uses for strftime_now).
fn today_ymd() -> String {
use std::time::{SystemTime, UNIX_EPOCH};
let secs = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
let secs = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
let z = (secs / 86400) as i64 + 719468;
let era = (if z >= 0 { z } else { z - 146096 }) / 146097;
let doe = z - era * 146097;
@@ -679,8 +822,23 @@ fn today_ymd() -> String {
format!("{y:04}-{m:02}-{d:02}")
}
fn chat_decode(
model: &ChatModel,
decoder: &mut GraphedGptOssDecoder,
tokens: &[u32],
positions: &[usize],
slots: &[usize],
cache: &mut PagedKVCache,
) -> xserv_tensor::Tensor {
match model {
ChatModel::GptOss(m) => decoder.decode(m, tokens, positions, slots, cache),
ChatModel::Qwen3(_) => model.forward_decode_paged(tokens, positions, slots, cache),
}
}
fn generate_with_paged_cache(
model: &ChatModel,
decoder: &mut GraphedGptOssDecoder,
cache: &mut PagedKVCache,
tokenizer: &Tokenizer,
prompt_tokens: &[u32],
@@ -691,12 +849,32 @@ fn generate_with_paged_cache(
is_moe: bool,
enable_thinking: bool,
) -> (Finish, String) {
let harmony_end_id = if is_moe { tokenizer.special_token_id("<|end|>") } else { None };
let harmony_channel_id = if is_moe { tokenizer.special_token_id("<|channel|>") } else { None };
let harmony_message_id = if is_moe { tokenizer.special_token_id("<|message|>") } else { None };
let harmony_end_id = if is_moe {
tokenizer.special_token_id("<|end|>")
} else {
None
};
let harmony_channel_id = if is_moe {
tokenizer.special_token_id("<|channel|>")
} else {
None
};
let harmony_message_id = if is_moe {
tokenizer.special_token_id("<|message|>")
} else {
None
};
let harmony_special: Vec<u32> = if is_moe {
["<|channel|>", "<|start|>", "<|end|>", "<|message|>", "<|return|>"]
.iter().filter_map(|s| tokenizer.special_token_id(s)).collect()
[
"<|channel|>",
"<|start|>",
"<|end|>",
"<|message|>",
"<|return|>",
]
.iter()
.filter_map(|s| tokenizer.special_token_id(s))
.collect()
} else {
Vec::new()
};
@@ -704,18 +882,29 @@ fn generate_with_paged_cache(
// "analysis" channel is rendered as thinking (gray). After <|channel|>
// we read the channel name tokens until <|message|>.
#[derive(PartialEq, Clone, Copy)]
enum HarmonyState { Normal, ReadingChannel, InAnalysis, InFinal }
let mut hstate = if is_moe { HarmonyState::InFinal } else { HarmonyState::Normal };
enum HarmonyState {
Normal,
ReadingChannel,
InAnalysis,
InFinal,
}
let mut hstate = if is_moe {
HarmonyState::InFinal
} else {
HarmonyState::Normal
};
// Off by default. A repetition penalty over a harmony stream penalizes the
// control tokens (<|channel|>, <|message|>, <|start|>) that MUST repeat to
// open the final channel — so a non-1.0 default makes gpt-oss stop right
// after the analysis block, before emitting any answer. Opt in via the env
// var if you want it for plain (non-harmony) generation.
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY").ok()
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(1.0);
let rep_window: usize = std::env::var("XSERV_REP_WINDOW").ok()
let rep_window: usize = std::env::var("XSERV_REP_WINDOW")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(512);
let mut history: Vec<u32> = Vec::new();
@@ -729,9 +918,16 @@ fn generate_with_paged_cache(
}
};
if let Some(h) = tp { h.send(TpCommand::Prefill { tokens: prompt_tokens.to_vec(), slot: SLOT }); }
if let Some(h) = tp {
h.send(TpCommand::Prefill {
tokens: prompt_tokens.to_vec(),
slot: SLOT,
});
}
let logits = model.forward_prefill_paged(prompt_tokens, SLOT, cache);
if let Some(h) = tp { h.wait(); }
if let Some(h) = tp {
h.wait();
}
let mut next = pick(&logits, sampling, &history);
let mut decode_buffer = Vec::new();
let mut in_thinking = false;
@@ -744,9 +940,17 @@ fn generate_with_paged_cache(
for _ in 0..max_tokens {
let position = cache.seq_len(SLOT);
if let Some(h) = tp { h.send(TpCommand::Decode { tokens: vec![next], positions: vec![position], slots: vec![SLOT] }); }
let logits = model.forward_decode_paged(&[next], &[position], &[SLOT], cache);
if let Some(h) = tp { h.wait(); }
if let Some(h) = tp {
h.send(TpCommand::Decode {
tokens: vec![next],
positions: vec![position],
slots: vec![SLOT],
});
}
let logits = chat_decode(model, decoder, &[next], &[position], &[SLOT], cache);
if let Some(h) = tp {
h.wait();
}
if tokenizer.is_eos(next) {
print_stream_text(
&tokenizer.flush_decode_stream(&mut decode_buffer),
@@ -757,7 +961,10 @@ fn generate_with_paged_cache(
print_stream_text("\n</think>\n\n", true, use_color);
}
io::stdout().flush().unwrap();
return (Finish::Stop { token_id: next }, tokenizer.decode(&answer_ids));
return (
Finish::Stop { token_id: next },
tokenizer.decode(&answer_ids),
);
}
if harmony_end_id == Some(next) {
// <|end|> closes current segment; if in final channel, we're done
@@ -768,7 +975,10 @@ fn generate_with_paged_cache(
);
if hstate == HarmonyState::InFinal {
io::stdout().flush().unwrap();
return (Finish::Stop { token_id: next }, tokenizer.decode(&answer_ids));
return (
Finish::Stop { token_id: next },
tokenizer.decode(&answer_ids),
);
}
// Closing a thinking (analysis/commentary) channel: emit the </think>
// marker so it renders like Qwen3's thinking block.
@@ -824,7 +1034,13 @@ fn generate_with_paged_cache(
// Analysis channel = the model's reasoning. With --think, show it as a
// thinking block (gray if color); otherwise suppress it (answer only).
if show_thinking {
print_generated_token(tokenizer, next, &mut decode_buffer, &mut in_thinking, use_color);
print_generated_token(
tokenizer,
next,
&mut decode_buffer,
&mut in_thinking,
use_color,
);
io::stdout().flush().unwrap();
}
next = pick(&logits, sampling, &history);
@@ -886,9 +1102,16 @@ fn append_text_to_cache(
if tokens.is_empty() || cache.seq_len(SLOT) + tokens.len() > max_seq_len {
return;
}
if let Some(h) = tp { h.send(TpCommand::Prefill { tokens: tokens.clone(), slot: SLOT }); }
if let Some(h) = tp {
h.send(TpCommand::Prefill {
tokens: tokens.clone(),
slot: SLOT,
});
}
let _ = model.forward_prefill_paged(&tokens, SLOT, cache);
if let Some(h) = tp { h.wait(); }
if let Some(h) = tp {
h.wait();
}
}
fn print_generated_token(
@@ -934,4 +1157,3 @@ fn print_stream_text(text: &str, in_thinking: bool, use_color: bool) {
print!("{text}");
}
}

View File

@@ -1,6 +1,6 @@
use std::io::{self, Write};
use std::path::PathBuf;
use xserv_model::{loader, KVCache, ModelConfig, PagedKVCache, BLOCK_SIZE};
use xserv_model::{BLOCK_SIZE, KVCache, ModelConfig, PagedKVCache, loader};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
@@ -21,14 +21,21 @@ fn main() {
xserv_cuda::device::set_device(0).unwrap();
let info = xserv_cuda::device::device_info(0).unwrap();
eprintln!("GPU: {} ({} MB free)", info.name, info.free_memory / 1024 / 1024);
eprintln!(
"GPU: {} ({} MB free)",
info.name,
info.free_memory / 1024 / 1024
);
let config = ModelConfig::from_file(&model_dir.join("config.json"));
let model_type = config.model_type.as_deref().unwrap_or("unknown");
eprintln!(
"Model: {model_type}, layers={}, hidden={}, heads={}/{} kv, vocab={}",
config.num_layers(), config.hidden(), config.num_heads(),
config.num_kv_heads(), config.vocab_size
config.num_layers(),
config.hidden(),
config.num_heads(),
config.num_kv_heads(),
config.vocab_size
);
eprintln!("Loading weights...");
@@ -37,7 +44,11 @@ fn main() {
let is_qwen3 = model_type.contains("qwen");
let is_gpt_oss = model_type.contains("gpt_oss");
let dtype = if is_qwen3 || is_gpt_oss { DType::BF16 } else { DType::F32 };
let dtype = if is_qwen3 || is_gpt_oss {
DType::BF16
} else {
DType::F32
};
// Build model
enum Model {
@@ -60,10 +71,16 @@ fn main() {
print!("xserv> ");
io::stdout().flush().unwrap();
let mut input = String::new();
if io::stdin().read_line(&mut input).unwrap() == 0 { break; }
if io::stdin().read_line(&mut input).unwrap() == 0 {
break;
}
let input = input.trim();
if input.is_empty() { continue; }
if input == "quit" || input == "exit" { break; }
if input.is_empty() {
continue;
}
if input == "quit" || input == "exit" {
break;
}
let token_ids = tokenizer.encode(input);
@@ -73,12 +90,21 @@ fn main() {
let max_blocks_per_seq = (max_seq + BLOCK_SIZE - 1) / BLOCK_SIZE;
let total_blocks = max_blocks_per_seq + 64;
let mut paged_cache = PagedKVCache::new(
&config, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, 0,
&config,
total_blocks,
0,
4,
max_blocks_per_seq,
DType::BF16,
0,
);
let slot = 0;
paged_cache.register_sequence(slot).expect("register slot");
let model = match &model { Model::GptOss(m) => m, _ => unreachable!() };
let model = match &model {
Model::GptOss(m) => m,
_ => unreachable!(),
};
let logits = model.forward_prefill_paged(&token_ids, slot, &mut paged_cache);
let mut next = sample_greedy_last(&logits);
@@ -90,20 +116,28 @@ fn main() {
print!("{text}");
io::stdout().flush().unwrap();
if tokenizer.eos_token_id() == Some(next) { break; }
if tokenizer.eos_token_id() == Some(next) {
break;
}
let pos = paged_cache.seq_len(slot);
let logits = model.forward_decode_paged(
&[next], &[pos], &[slot], &mut paged_cache,
);
let logits = model.forward_decode_paged(&[next], &[pos], &[slot], &mut paged_cache);
next = sample_greedy_last(&logits);
}
println!();
paged_cache.free_sequence(slot);
} else {
let kv_heads = if is_qwen3 { config.num_kv_heads() } else { config.num_heads() };
let kv_heads = if is_qwen3 {
config.num_kv_heads()
} else {
config.num_heads()
};
let mut cache = KVCache::new(
config.num_layers(), kv_heads, config.head_dim(), dtype, Device::Cuda(0),
config.num_layers(),
kv_heads,
config.head_dim(),
dtype,
Device::Cuda(0),
);
let logits = match &model {
@@ -125,7 +159,9 @@ fn main() {
print!("{text}");
io::stdout().flush().unwrap();
if tokenizer.eos_token_id() == Some(next) { break; }
if tokenizer.eos_token_id() == Some(next) {
break;
}
let logits = match &model {
Model::GPT2(m) => m.forward_with_cache(&[next], &mut cache),
@@ -151,7 +187,9 @@ fn sample_greedy_last(logits: &xserv_tensor::Tensor) -> u32 {
let seq_len = logits.shape()[0];
let data = logits_cpu.as_slice::<bf16>();
let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
last.iter().enumerate()
last.iter()
.enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(i, _)| i as u32).unwrap()
.map(|(i, _)| i as u32)
.unwrap()
}

View File

@@ -88,23 +88,33 @@ impl ModelConfig {
}
pub fn hidden(&self) -> usize {
self.hidden_size.or(self.n_embd).expect("hidden_size or n_embd required")
self.hidden_size
.or(self.n_embd)
.expect("hidden_size or n_embd required")
}
pub fn num_heads(&self) -> usize {
self.num_attention_heads.or(self.n_head).expect("num_attention_heads or n_head required")
self.num_attention_heads
.or(self.n_head)
.expect("num_attention_heads or n_head required")
}
pub fn num_layers(&self) -> usize {
self.num_hidden_layers.or(self.n_layer).expect("num_hidden_layers or n_layer required")
self.num_hidden_layers
.or(self.n_layer)
.expect("num_hidden_layers or n_layer required")
}
pub fn max_seq_len(&self) -> usize {
self.max_position_embeddings.or(self.n_positions).unwrap_or(2048)
self.max_position_embeddings
.or(self.n_positions)
.unwrap_or(2048)
}
pub fn ffn_hidden(&self) -> usize {
self.intermediate_size.or(self.n_inner).unwrap_or(self.hidden() * 4)
self.intermediate_size
.or(self.n_inner)
.unwrap_or(self.hidden() * 4)
}
pub fn num_kv_heads(&self) -> usize {
@@ -112,7 +122,8 @@ impl ModelConfig {
}
pub fn head_dim(&self) -> usize {
self.explicit_head_dim.unwrap_or_else(|| self.hidden() / self.num_heads())
self.explicit_head_dim
.unwrap_or_else(|| self.hidden() / self.num_heads())
}
pub fn ln_eps(&self) -> f32 {

View File

@@ -18,19 +18,19 @@ use crate::kv_cache::GpuKVCache;
/// All buffers have stable GPU addresses for CUDA Graph replay.
struct DecodeBuffers {
// Hidden-size buffers: [1, hidden]
x: GpuBuffer, // running hidden state
normed: GpuBuffer, // rmsnorm output
attn_out: GpuBuffer, // attention output [1, num_heads, 1, head_dim]
attn_merged: GpuBuffer, // merge_heads output [1, hidden]
o_proj: GpuBuffer, // O projection output [1, hidden]
normed2: GpuBuffer, // post-attn norm output [1, hidden]
sum_out: GpuBuffer, // add_rmsnorm sum output [1, hidden]
down: GpuBuffer, // down projection output [1, hidden]
x: GpuBuffer, // running hidden state
normed: GpuBuffer, // rmsnorm output
attn_out: GpuBuffer, // attention output [1, num_heads, 1, head_dim]
attn_merged: GpuBuffer, // merge_heads output [1, hidden]
o_proj: GpuBuffer, // O projection output [1, hidden]
normed2: GpuBuffer, // post-attn norm output [1, hidden]
sum_out: GpuBuffer, // add_rmsnorm sum output [1, hidden]
down: GpuBuffer, // down projection output [1, hidden]
// QKV projection outputs
q_proj: GpuBuffer, // [1, num_heads * head_dim]
k_proj: GpuBuffer, // [1, num_kv_heads * head_dim]
v_proj: GpuBuffer, // [1, num_kv_heads * head_dim]
q_proj: GpuBuffer, // [1, num_heads * head_dim]
k_proj: GpuBuffer, // [1, num_kv_heads * head_dim]
v_proj: GpuBuffer, // [1, num_kv_heads * head_dim]
// Reshaped: [1, H, 1, D]
q_reshaped: GpuBuffer,
@@ -50,23 +50,23 @@ struct DecodeBuffers {
k_final: GpuBuffer,
// FFN intermediates
gate: GpuBuffer, // [1, intermediate]
up: GpuBuffer, // [1, intermediate]
silu_out: GpuBuffer, // [1, intermediate]
gate: GpuBuffer, // [1, intermediate]
up: GpuBuffer, // [1, intermediate]
silu_out: GpuBuffer, // [1, intermediate]
// GEMV fp32 accumulators (separate per output dimension)
fp32_hidden: GpuBuffer, // for hidden-sized GEMV outputs
fp32_q: GpuBuffer, // for Q projection
fp32_kv: GpuBuffer, // for K/V projection
fp32_intermediate: GpuBuffer,// for gate/up projections
fp32_vocab: GpuBuffer, // for lm_head
fp32_hidden: GpuBuffer, // for hidden-sized GEMV outputs
fp32_q: GpuBuffer, // for Q projection
fp32_kv: GpuBuffer, // for K/V projection
fp32_intermediate: GpuBuffer, // for gate/up projections
fp32_vocab: GpuBuffer, // for lm_head
// Token ID and position (GPU-resident, updated before replay)
token_id_gpu: GpuBuffer, // 4 bytes (u32)
position_gpu: GpuBuffer, // 4 bytes (u32)
token_id_gpu: GpuBuffer, // 4 bytes (u32)
position_gpu: GpuBuffer, // 4 bytes (u32)
// Final output
logits: GpuBuffer, // [1, vocab_size]
logits: GpuBuffer, // [1, vocab_size]
}
pub struct DecodeGraphState {
@@ -199,127 +199,296 @@ impl DecodeGraphState {
let cublas = cublas_handle();
// Set cuBLAS to use our stream
unsafe { dispatch::set_cublas_stream(cublas, s); }
unsafe {
dispatch::set_cublas_stream(cublas, s);
}
for (l, lw) in layers.iter().enumerate() {
// === Pre-attention graph ===
self.pre_attn_graphs[l].begin_capture(&self.stream).expect("begin pre-attn capture");
self.pre_attn_graphs[l]
.begin_capture(&self.stream)
.expect("begin pre-attn capture");
unsafe {
// RMSNorm
dispatch::rmsnorm_bf16(
self.buffers.x.as_ptr() as _, lw.input_norm, self.buffers.normed.as_mut_ptr() as _,
1, h, eps, s,
self.buffers.x.as_ptr() as _,
lw.input_norm,
self.buffers.normed.as_mut_ptr() as _,
1,
h,
eps,
s,
);
// Q projection (GEMV)
dispatch::gemv_bf16(
self.buffers.normed.as_ptr() as _, lw.q_proj_wt, self.buffers.q_proj.as_mut_ptr() as _,
self.buffers.normed.as_ptr() as _,
lw.q_proj_wt,
self.buffers.q_proj.as_mut_ptr() as _,
self.buffers.fp32_q.as_mut_ptr() as _,
h, nh * hd, s,
h,
nh * hd,
s,
);
// K projection (GEMV)
dispatch::gemv_bf16(
self.buffers.normed.as_ptr() as _, lw.k_proj_wt, self.buffers.k_proj.as_mut_ptr() as _,
self.buffers.normed.as_ptr() as _,
lw.k_proj_wt,
self.buffers.k_proj.as_mut_ptr() as _,
self.buffers.fp32_kv.as_mut_ptr() as _,
h, nkv * hd, s,
h,
nkv * hd,
s,
);
// V projection (GEMV)
dispatch::gemv_bf16(
self.buffers.normed.as_ptr() as _, lw.v_proj_wt, self.buffers.v_proj.as_mut_ptr() as _,
self.buffers.normed.as_ptr() as _,
lw.v_proj_wt,
self.buffers.v_proj.as_mut_ptr() as _,
self.buffers.fp32_kv.as_mut_ptr() as _,
h, nkv * hd, s,
h,
nkv * hd,
s,
);
// Reshape heads: [1, H*D] -> [1, H, 1, D]
dispatch::reshape_heads_bf16(self.buffers.q_proj.as_ptr() as _, self.buffers.q_reshaped.as_mut_ptr() as _, 1, nh, hd, s);
dispatch::reshape_heads_bf16(self.buffers.k_proj.as_ptr() as _, self.buffers.k_reshaped.as_mut_ptr() as _, 1, nkv, hd, s);
dispatch::reshape_heads_bf16(self.buffers.v_proj.as_ptr() as _, self.buffers.v_reshaped.as_mut_ptr() as _, 1, nkv, hd, s);
dispatch::reshape_heads_bf16(
self.buffers.q_proj.as_ptr() as _,
self.buffers.q_reshaped.as_mut_ptr() as _,
1,
nh,
hd,
s,
);
dispatch::reshape_heads_bf16(
self.buffers.k_proj.as_ptr() as _,
self.buffers.k_reshaped.as_mut_ptr() as _,
1,
nkv,
hd,
s,
);
dispatch::reshape_heads_bf16(
self.buffers.v_proj.as_ptr() as _,
self.buffers.v_reshaped.as_mut_ptr() as _,
1,
nkv,
hd,
s,
);
// QK norm (head-level rmsnorm: treat [1,H,1,D] as [H, D])
dispatch::rmsnorm_bf16(self.buffers.q_reshaped.as_ptr() as _, lw.q_norm, self.buffers.q_normed.as_mut_ptr() as _, nh, hd, eps, s);
dispatch::rmsnorm_bf16(self.buffers.k_reshaped.as_ptr() as _, lw.k_norm, self.buffers.k_normed.as_mut_ptr() as _, nkv, hd, eps, s);
dispatch::rmsnorm_bf16(
self.buffers.q_reshaped.as_ptr() as _,
lw.q_norm,
self.buffers.q_normed.as_mut_ptr() as _,
nh,
hd,
eps,
s,
);
dispatch::rmsnorm_bf16(
self.buffers.k_reshaped.as_ptr() as _,
lw.k_norm,
self.buffers.k_normed.as_mut_ptr() as _,
nkv,
hd,
eps,
s,
);
// Transpose for RoPE: [1,H,1,D] -> [1,H,D]
dispatch::transpose_hsd_to_shd_bf16(self.buffers.q_normed.as_ptr() as _, self.buffers.q_rope.as_mut_ptr() as _, 1, nh, hd, s);
dispatch::transpose_hsd_to_shd_bf16(self.buffers.k_normed.as_ptr() as _, self.buffers.k_rope.as_mut_ptr() as _, 1, nkv, hd, s);
dispatch::transpose_hsd_to_shd_bf16(
self.buffers.q_normed.as_ptr() as _,
self.buffers.q_rope.as_mut_ptr() as _,
1,
nh,
hd,
s,
);
dispatch::transpose_hsd_to_shd_bf16(
self.buffers.k_normed.as_ptr() as _,
self.buffers.k_rope.as_mut_ptr() as _,
1,
nkv,
hd,
s,
);
// RoPE (in-place, reads position_gpu)
dispatch::rope_bf16(self.buffers.q_rope.as_mut_ptr() as _, rope_cos, rope_sin, self.buffers.position_gpu.as_ptr() as _, 1, nh, hd, s);
dispatch::rope_bf16(self.buffers.k_rope.as_mut_ptr() as _, rope_cos, rope_sin, self.buffers.position_gpu.as_ptr() as _, 1, nkv, hd, s);
dispatch::rope_bf16(
self.buffers.q_rope.as_mut_ptr() as _,
rope_cos,
rope_sin,
self.buffers.position_gpu.as_ptr() as _,
1,
nh,
hd,
s,
);
dispatch::rope_bf16(
self.buffers.k_rope.as_mut_ptr() as _,
rope_cos,
rope_sin,
self.buffers.position_gpu.as_ptr() as _,
1,
nkv,
hd,
s,
);
// Transpose back: [1,H,D] -> [1,H,1,D]
dispatch::transpose_shd_to_hsd_bf16(self.buffers.q_rope.as_ptr() as _, self.buffers.q_final.as_mut_ptr() as _, 1, nh, hd, s);
dispatch::transpose_shd_to_hsd_bf16(self.buffers.k_rope.as_ptr() as _, self.buffers.k_final.as_mut_ptr() as _, 1, nkv, hd, s);
dispatch::transpose_shd_to_hsd_bf16(
self.buffers.q_rope.as_ptr() as _,
self.buffers.q_final.as_mut_ptr() as _,
1,
nh,
hd,
s,
);
dispatch::transpose_shd_to_hsd_bf16(
self.buffers.k_rope.as_ptr() as _,
self.buffers.k_final.as_mut_ptr() as _,
1,
nkv,
hd,
s,
);
}
self.pre_attn_graphs[l].end_capture(&self.stream).expect("end pre-attn capture");
self.pre_attn_graphs[l]
.end_capture(&self.stream)
.expect("end pre-attn capture");
// === Post-attention graph ===
self.post_attn_graphs[l].begin_capture(&self.stream).expect("begin post-attn capture");
self.post_attn_graphs[l]
.begin_capture(&self.stream)
.expect("begin post-attn capture");
unsafe {
// Merge heads: [1,H,1,D] -> [1, hidden]
// attn_out is written by ungraphed attention
dispatch::merge_heads_bf16(self.buffers.attn_out.as_ptr() as _, self.buffers.attn_merged.as_mut_ptr() as _, 1, nh, hd, s);
dispatch::merge_heads_bf16(
self.buffers.attn_out.as_ptr() as _,
self.buffers.attn_merged.as_mut_ptr() as _,
1,
nh,
hd,
s,
);
// O projection
dispatch::gemv_bf16(
self.buffers.attn_merged.as_ptr() as _, lw.o_proj_wt, self.buffers.o_proj.as_mut_ptr() as _,
self.buffers.attn_merged.as_ptr() as _,
lw.o_proj_wt,
self.buffers.o_proj.as_mut_ptr() as _,
self.buffers.fp32_hidden.as_mut_ptr() as _,
nh * hd, h, s,
nh * hd,
h,
s,
);
// Fused Add+RMSNorm: normed2 = rmsnorm(o_proj + x), sum_out = o_proj + x
dispatch::add_rmsnorm_bf16(
self.buffers.o_proj.as_ptr() as _, self.buffers.x.as_ptr() as _, lw.post_norm,
self.buffers.normed2.as_mut_ptr() as _, self.buffers.sum_out.as_mut_ptr() as _,
1, h, eps, s,
self.buffers.o_proj.as_ptr() as _,
self.buffers.x.as_ptr() as _,
lw.post_norm,
self.buffers.normed2.as_mut_ptr() as _,
self.buffers.sum_out.as_mut_ptr() as _,
1,
h,
eps,
s,
);
// Gate projection
dispatch::gemv_bf16(
self.buffers.normed2.as_ptr() as _, lw.gate_proj_wt, self.buffers.gate.as_mut_ptr() as _,
self.buffers.normed2.as_ptr() as _,
lw.gate_proj_wt,
self.buffers.gate.as_mut_ptr() as _,
self.buffers.fp32_intermediate.as_mut_ptr() as _,
h, inter, s,
h,
inter,
s,
);
// Up projection
dispatch::gemv_bf16(
self.buffers.normed2.as_ptr() as _, lw.up_proj_wt, self.buffers.up.as_mut_ptr() as _,
self.buffers.normed2.as_ptr() as _,
lw.up_proj_wt,
self.buffers.up.as_mut_ptr() as _,
self.buffers.fp32_intermediate.as_mut_ptr() as _,
h, inter, s,
h,
inter,
s,
);
// Fused SiLU x Mul
dispatch::silu_mul_bf16(self.buffers.gate.as_ptr() as _, self.buffers.up.as_ptr() as _, self.buffers.silu_out.as_mut_ptr() as _, inter, s);
dispatch::silu_mul_bf16(
self.buffers.gate.as_ptr() as _,
self.buffers.up.as_ptr() as _,
self.buffers.silu_out.as_mut_ptr() as _,
inter,
s,
);
// Down projection
dispatch::gemv_bf16(
self.buffers.silu_out.as_ptr() as _, lw.down_proj_wt, self.buffers.down.as_mut_ptr() as _,
self.buffers.silu_out.as_ptr() as _,
lw.down_proj_wt,
self.buffers.down.as_mut_ptr() as _,
self.buffers.fp32_hidden.as_mut_ptr() as _,
inter, h, s,
inter,
h,
s,
);
// x = sum_out + down (residual connection for next layer)
dispatch::add_bf16(self.buffers.sum_out.as_ptr() as _, self.buffers.down.as_ptr() as _, self.buffers.x.as_mut_ptr() as _, h, s);
dispatch::add_bf16(
self.buffers.sum_out.as_ptr() as _,
self.buffers.down.as_ptr() as _,
self.buffers.x.as_mut_ptr() as _,
h,
s,
);
}
self.post_attn_graphs[l].end_capture(&self.stream).expect("end post-attn capture");
self.post_attn_graphs[l]
.end_capture(&self.stream)
.expect("end post-attn capture");
}
// === Final graph: norm + lm_head ===
self.final_graph.begin_capture(&self.stream).expect("begin final capture");
self.final_graph
.begin_capture(&self.stream)
.expect("begin final capture");
unsafe {
dispatch::rmsnorm_bf16(self.buffers.x.as_ptr() as _, norm_weight, self.buffers.normed.as_mut_ptr() as _, 1, h, eps, s);
dispatch::rmsnorm_bf16(
self.buffers.x.as_ptr() as _,
norm_weight,
self.buffers.normed.as_mut_ptr() as _,
1,
h,
eps,
s,
);
dispatch::gemv_bf16(
self.buffers.normed.as_ptr() as _, lm_head_wt, self.buffers.logits.as_mut_ptr() as _,
self.buffers.normed.as_ptr() as _,
lm_head_wt,
self.buffers.logits.as_mut_ptr() as _,
self.buffers.fp32_vocab.as_mut_ptr() as _,
h, vocab, s,
h,
vocab,
s,
);
}
self.final_graph.end_capture(&self.stream).expect("end final capture");
self.final_graph
.end_capture(&self.stream)
.expect("end final capture");
// Reset cuBLAS back to null stream
unsafe { dispatch::set_cublas_stream(cublas, std::ptr::null_mut()); }
unsafe {
dispatch::set_cublas_stream(cublas, std::ptr::null_mut());
}
self.captured = true;
}
@@ -343,8 +512,14 @@ impl DecodeGraphState {
let es = 2usize; // BF16
// Upload token ID and position to fixed GPU buffers
self.buffers.token_id_gpu.copy_from_host(&token_id.to_le_bytes()).unwrap();
self.buffers.position_gpu.copy_from_host(&position.to_le_bytes()).unwrap();
self.buffers
.token_id_gpu
.copy_from_host(&token_id.to_le_bytes())
.unwrap();
self.buffers
.position_gpu
.copy_from_host(&position.to_le_bytes())
.unwrap();
// Embedding (outside graph since token_id changes each step)
unsafe {
@@ -352,13 +527,18 @@ impl DecodeGraphState {
embed_table,
self.buffers.token_id_gpu.as_ptr() as _,
self.buffers.x.as_mut_ptr() as _,
1, hidden_size, vocab_size, s,
1,
hidden_size,
vocab_size,
s,
);
}
for l in 0..self.num_layers {
// Pre-attention graph (norm + QKV + reshape + QK-norm + RoPE)
self.pre_attn_graphs[l].launch(&self.stream).expect("launch pre-attn graph");
self.pre_attn_graphs[l]
.launch(&self.stream)
.expect("launch pre-attn graph");
// Ungraphed: KV cache append
// k_final shape: [1, num_kv_heads, 1, head_dim] (after RoPE pipeline)
@@ -402,9 +582,13 @@ impl DecodeGraphState {
k_full.data_ptr() as _,
v_full.data_ptr() as _,
self.buffers.attn_out.as_mut_ptr() as _,
1, nh as i32, nkv as i32,
kv_len, hd as i32,
scale, s,
1,
nh as i32,
nkv as i32,
kv_len,
hd as i32,
scale,
s,
);
}
@@ -412,11 +596,15 @@ impl DecodeGraphState {
self.stream.synchronize().expect("sync before post-attn");
// Post-attention graph (merge + O-proj + add_rmsnorm + FFN + residual)
self.post_attn_graphs[l].launch(&self.stream).expect("launch post-attn graph");
self.post_attn_graphs[l]
.launch(&self.stream)
.expect("launch post-attn graph");
}
// Final graph (norm + lm_head)
self.final_graph.launch(&self.stream).expect("launch final graph");
self.final_graph
.launch(&self.stream)
.expect("launch final graph");
// Sync to ensure logits are ready
self.stream.synchronize().expect("sync after decode");

View File

@@ -31,7 +31,7 @@ struct GPT2Block {
pub struct KVCache {
// Per layer, per head: raw bytes (works for both f32 and bf16)
k: Vec<Vec<Vec<u8>>>, // [num_layers][num_heads][seq_len * head_dim * elem_size]
k: Vec<Vec<Vec<u8>>>, // [num_layers][num_heads][seq_len * head_dim * elem_size]
v: Vec<Vec<Vec<u8>>>,
len: usize,
num_heads: usize,
@@ -42,7 +42,13 @@ pub struct KVCache {
}
impl KVCache {
pub fn new(num_layers: usize, num_heads: usize, head_dim: usize, dtype: DType, device: Device) -> Self {
pub fn new(
num_layers: usize,
num_heads: usize,
head_dim: usize,
dtype: DType,
device: Device,
) -> Self {
Self {
k: (0..num_layers).map(|_| vec![vec![]; num_heads]).collect(),
v: (0..num_layers).map(|_| vec![vec![]; num_heads]).collect(),
@@ -55,10 +61,18 @@ impl KVCache {
}
}
pub fn seq_len(&self) -> usize { self.len }
pub fn seq_len(&self) -> usize {
self.len
}
/// Append from a CPU tensor with shape [1, H, new_tokens, D].
pub fn append_kv_tensor(&mut self, layer: usize, k_cpu: &Tensor, v_cpu: &Tensor, new_tokens: usize) {
pub fn append_kv_tensor(
&mut self,
layer: usize,
k_cpu: &Tensor,
v_cpu: &Tensor,
new_tokens: usize,
) {
let hd = self.head_dim;
let es = self.elem_size;
let k_bytes = k_cpu.storage().as_cpu_bytes();
@@ -118,7 +132,8 @@ impl GPT2 {
pub fn from_weights(config: ModelConfig, mut w: HashMap<String, Tensor>) -> Self {
crate::init_kernels();
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
w.remove(name)
.unwrap_or_else(|| panic!("missing weight: {name}"))
};
let wte = take(&mut w, "wte.weight");
@@ -147,7 +162,15 @@ impl GPT2 {
});
}
Self { config, wte, wpe, layers, ln_f_g, ln_f_b, lm_head }
Self {
config,
wte,
wpe,
layers,
ln_f_g,
ln_f_b,
lm_head,
}
}
/// Full forward pass without KV cache (for testing / correctness comparison).
@@ -179,14 +202,22 @@ impl GPT2 {
let head_dim = self.config.head_dim();
let tok_emb = embedding(&self.wte, token_ids);
let pos_ids: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
let pos_ids: Vec<u32> = (pos_offset..pos_offset + new_tokens)
.map(|p| p as u32)
.collect();
let pos_emb = embedding(&self.wpe, &pos_ids);
let mut x = add_tensors(&tok_emb, &pos_emb);
for (layer_idx, layer) in self.layers.iter().enumerate() {
x = self.transformer_block(
layer, &x, Some((cache, layer_idx)),
pos_offset, new_tokens, num_heads, head_dim, hidden,
layer,
&x,
Some((cache, layer_idx)),
pos_offset,
new_tokens,
num_heads,
head_dim,
hidden,
);
}
@@ -199,7 +230,7 @@ impl GPT2 {
layer: &GPT2Block,
x: &Tensor,
cache: Option<(&mut KVCache, usize)>,
pos_offset: usize,
_pos_offset: usize,
new_tokens: usize,
num_heads: usize,
head_dim: usize,
@@ -238,7 +269,11 @@ impl GPT2 {
fn linear(x: &Tensor, weight: &Tensor, bias: Option<&Tensor>) -> Tensor {
let out = matmul_2d(x, weight);
if let Some(b) = bias { add_bias(&out, b) } else { out }
if let Some(b) = bias {
add_bias(&out, b)
} else {
out
}
}
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
@@ -277,7 +312,12 @@ fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor {
}
}
fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) -> (Tensor, Tensor, Tensor) {
fn split_qkv(
qkv: &Tensor,
num_heads: usize,
head_dim: usize,
seq_len: usize,
) -> (Tensor, Tensor, Tensor) {
let hidden = num_heads * head_dim;
let qkv_cpu = qkv.to_device(Device::Cpu);
let device = qkv.device();
@@ -294,14 +334,21 @@ fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) ->
for h in 0..num_heads {
let src_off = h * head_dim;
let dst_off = (h * seq_len + s) * head_dim;
q_data[dst_off..dst_off + head_dim].copy_from_slice(&row[src_off..src_off + head_dim]);
k_data[dst_off..dst_off + head_dim].copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
v_data[dst_off..dst_off + head_dim].copy_from_slice(&row[2 * hidden + src_off..2 * hidden + src_off + head_dim]);
q_data[dst_off..dst_off + head_dim]
.copy_from_slice(&row[src_off..src_off + head_dim]);
k_data[dst_off..dst_off + head_dim]
.copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
v_data[dst_off..dst_off + head_dim].copy_from_slice(
&row[2 * hidden + src_off..2 * hidden + src_off + head_dim],
);
}
}
let q = Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let k = Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let v = Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let q =
Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let k =
Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let v =
Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
(q, k, v)
}
DType::BF16 => {
@@ -314,14 +361,21 @@ fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) ->
for h in 0..num_heads {
let src_off = h * head_dim;
let dst_off = (h * seq_len + s) * head_dim;
q_data[dst_off..dst_off + head_dim].copy_from_slice(&row[src_off..src_off + head_dim]);
k_data[dst_off..dst_off + head_dim].copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
v_data[dst_off..dst_off + head_dim].copy_from_slice(&row[2 * hidden + src_off..2 * hidden + src_off + head_dim]);
q_data[dst_off..dst_off + head_dim]
.copy_from_slice(&row[src_off..src_off + head_dim]);
k_data[dst_off..dst_off + head_dim]
.copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
v_data[dst_off..dst_off + head_dim].copy_from_slice(
&row[2 * hidden + src_off..2 * hidden + src_off + head_dim],
);
}
}
let q = Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let k = Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let v = Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let q =
Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let k =
Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let v =
Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
(q, k, v)
}
_ => panic!("unsupported dtype {:?} in split_qkv", dtype),
@@ -343,7 +397,8 @@ fn merge_heads(x: &Tensor, seq_len: usize, hidden: usize) -> Tensor {
for h in 0..num_heads {
let src_off = (h * seq_len + s) * head_dim;
let dst_off = s * hidden + h * head_dim;
out[dst_off..dst_off + head_dim].copy_from_slice(&src[src_off..src_off + head_dim]);
out[dst_off..dst_off + head_dim]
.copy_from_slice(&src[src_off..src_off + head_dim]);
}
}
Tensor::from_slice(&out, &[seq_len, hidden]).to_device(device)
@@ -355,7 +410,8 @@ fn merge_heads(x: &Tensor, seq_len: usize, hidden: usize) -> Tensor {
for h in 0..num_heads {
let src_off = (h * seq_len + s) * head_dim;
let dst_off = s * hidden + h * head_dim;
out[dst_off..dst_off + head_dim].copy_from_slice(&src[src_off..src_off + head_dim]);
out[dst_off..dst_off + head_dim]
.copy_from_slice(&src[src_off..src_off + head_dim]);
}
}
Tensor::from_slice(&out, &[seq_len, hidden]).to_device(device)
@@ -372,7 +428,8 @@ pub fn sample_greedy(logits: &Tensor) -> u32 {
let vocab_size = logits.shape()[1];
let seq_len = logits.shape()[0];
let last_row = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
last_row.iter()
last_row
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.map(|(idx, _)| idx as u32)

View File

@@ -1,6 +1,6 @@
use half::bf16;
use std::collections::HashMap;
use std::ffi::c_void;
use half::bf16;
use xserv_kernels::*;
use xserv_tensor::{Device, Tensor};
@@ -49,10 +49,10 @@ struct GptOssBlock {
expert_down_bias: Tensor, // [local_experts, hidden]
// FP8 quantized expert weights (Some when running FP8 W8A8)
// Transposed layout [E, N, K] for cuBLASLt FP8 (Blackwell requires transA=T)
expert_gate_up_fp8: Option<Tensor>, // [local_experts, 2*inter, hidden] FP8E4M3
expert_gate_up_scale: Option<Tensor>,// [local_experts] F32
expert_down_fp8: Option<Tensor>, // [local_experts, hidden, inter] FP8E4M3
expert_down_scale: Option<Tensor>, // [local_experts] F32
expert_gate_up_fp8: Option<Tensor>, // [local_experts, 2*inter, hidden] FP8E4M3
expert_gate_up_scale: Option<Tensor>, // [local_experts] F32
expert_down_fp8: Option<Tensor>, // [local_experts, hidden, inter] FP8E4M3
expert_down_scale: Option<Tensor>, // [local_experts] F32
// MXFP4 W4A16 expert weights (Some when running 4-bit weight-only).
// (packed [E, N, K/2] u8, scales [E, N, K/32] u8) in [E, N, K] layout.
expert_gate_up_mxfp4: Option<(Tensor, Tensor)>,
@@ -79,16 +79,23 @@ impl GptOss {
crate::init_kernels();
let dev = Device::Cuda(device);
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
w.remove(name)
.unwrap_or_else(|| panic!("missing weight: {name}"))
};
let repl = |t: Tensor| -> Tensor { t.to_device(dev) };
// column-parallel: shard rows of [out, in], transpose → [in, out/world]
let col = |t: Tensor| -> Tensor {
shard_rows(&t, rank, world).to_device(dev).transpose(0, 1).contiguous()
shard_rows(&t, rank, world)
.to_device(dev)
.transpose(0, 1)
.contiguous()
};
// row-parallel: shard cols of [out, in], transpose → [in/world, out]
let row = |t: Tensor| -> Tensor {
shard_cols(&t, rank, world).to_device(dev).transpose(0, 1).contiguous()
shard_cols(&t, rank, world)
.to_device(dev)
.transpose(0, 1)
.contiguous()
};
// Bias sharding helpers
let col_bias = |t: Tensor| -> Tensor { shard_1d(&t, rank, world).to_device(dev) };
@@ -97,7 +104,9 @@ impl GptOss {
let embed_tokens = repl(take(&mut w, "model.embed_tokens.weight"));
let norm = repl(take(&mut w, "model.norm.weight"));
let norm_bias = w.remove("model.norm.bias").map(|t| repl(t));
let lm_head_t = repl(take(&mut w, "lm_head.weight")).transpose(0, 1).contiguous();
let lm_head_t = repl(take(&mut w, "lm_head.weight"))
.transpose(0, 1)
.contiguous();
let head_dim = config.head_dim();
let rope_theta = config.rope_theta.unwrap_or(150000.0);
@@ -176,15 +185,30 @@ impl GptOss {
// MXFP4 stores 4-bit weights in an FP8E4M3 byte container (same dtype
// as FP8), so distinguish by the scale rank: FP8 scale is 1-D [E],
// MXFP4 scale is 3-D [E, N, K/32].
let is_mxfp4 = gate_up_scale.as_ref().map(|s| s.ndim() == 3).unwrap_or(false);
let is_mxfp4 = gate_up_scale
.as_ref()
.map(|s| s.ndim() == 3)
.unwrap_or(false);
let is_fp8 = !is_mxfp4 && gate_up_3d.dtype() == xserv_tensor::DType::FP8E4M3;
let mut expert_gate_up_mxfp4: Option<(Tensor, Tensor)> = None;
let mut expert_down_mxfp4: Option<(Tensor, Tensor)> = None;
let inter2 = if is_mxfp4 { gate_up_3d.shape()[1] } else { gate_up_3d.shape()[2] }; // 2*inter (N)
let hidden = if is_mxfp4 { gate_up_3d.shape()[2] * 2 } else { gate_up_3d.shape()[1] };
let inter = if is_mxfp4 { down_3d.shape()[2] * 2 } else { down_3d.shape()[1] };
let inter2 = if is_mxfp4 {
gate_up_3d.shape()[1]
} else {
gate_up_3d.shape()[2]
}; // 2*inter (N)
let hidden = if is_mxfp4 {
gate_up_3d.shape()[2] * 2
} else {
gate_up_3d.shape()[1]
};
let inter = if is_mxfp4 {
down_3d.shape()[2] * 2
} else {
down_3d.shape()[1]
};
// Slice the rank's range of experts as contiguous 3D tensors on GPU
let expert_gate_up_wt;
@@ -199,10 +223,38 @@ impl GptOss {
// + scales [E, N, K/32]. Slice this rank's experts (raw bytes).
let gu_s = gate_up_scale.expect("MXFP4 model missing gate_up_proj_scale");
let d_s = down_scale.expect("MXFP4 model missing down_proj_scale");
let gu_packed = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, inter2, hidden / 2).to_device(dev);
let gu_scl = slice_expert_range_3d_raw(&gu_s, expert_start, local_experts, inter2, hidden / 32).to_device(dev);
let dn_packed = slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, hidden, inter / 2).to_device(dev);
let dn_scl = slice_expert_range_3d_raw(&d_s, expert_start, local_experts, hidden, inter / 32).to_device(dev);
let gu_packed = slice_expert_range_3d_raw(
&gate_up_3d,
expert_start,
local_experts,
inter2,
hidden / 2,
)
.to_device(dev);
let gu_scl = slice_expert_range_3d_raw(
&gu_s,
expert_start,
local_experts,
inter2,
hidden / 32,
)
.to_device(dev);
let dn_packed = slice_expert_range_3d_raw(
&down_3d,
expert_start,
local_experts,
hidden,
inter / 2,
)
.to_device(dev);
let dn_scl = slice_expert_range_3d_raw(
&d_s,
expert_start,
local_experts,
hidden,
inter / 32,
)
.to_device(dev);
expert_gate_up_mxfp4 = Some((gu_packed, gu_scl));
expert_down_mxfp4 = Some((dn_packed, dn_scl));
expert_gate_up_fp8 = None;
@@ -214,36 +266,65 @@ impl GptOss {
} else if is_fp8 {
// FP8 W8A8 path: load and TRANSPOSE weights for cuBLASLt (requires transA=T on Blackwell).
// Original: [E, K, N] → Transposed: [E, N, K]
let gu_sliced = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, hidden, inter2);
let dn_sliced = slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, inter, hidden);
expert_gate_up_fp8 = Some(transpose_3d_inner_raw(&gu_sliced, local_experts, hidden, inter2).to_device(dev));
expert_down_fp8 = Some(transpose_3d_inner_raw(&dn_sliced, local_experts, inter, hidden).to_device(dev));
let gu_sliced = slice_expert_range_3d_raw(
&gate_up_3d,
expert_start,
local_experts,
hidden,
inter2,
);
let dn_sliced =
slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, inter, hidden);
expert_gate_up_fp8 = Some(
transpose_3d_inner_raw(&gu_sliced, local_experts, hidden, inter2)
.to_device(dev),
);
expert_down_fp8 = Some(
transpose_3d_inner_raw(&dn_sliced, local_experts, inter, hidden).to_device(dev),
);
// Scales: [num_experts] F32 → slice to [local_experts]
let gu_s = gate_up_scale.expect("FP8 model missing gate_up_proj_scale");
let d_s = down_scale.expect("FP8 model missing down_proj_scale");
expert_gate_up_scale_gpu = Some(slice_scale_range(&gu_s, expert_start, local_experts).to_device(dev));
expert_down_scale_gpu = Some(slice_scale_range(&d_s, expert_start, local_experts).to_device(dev));
expert_gate_up_scale_gpu =
Some(slice_scale_range(&gu_s, expert_start, local_experts).to_device(dev));
expert_down_scale_gpu =
Some(slice_scale_range(&d_s, expert_start, local_experts).to_device(dev));
// Dummy BF16 tensors (never read in FP8 path)
expert_gate_up_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev);
expert_down_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev);
} else {
// BF16 path: existing behavior
expert_gate_up_wt = slice_expert_range_3d(&gate_up_3d, expert_start, local_experts, hidden, inter2).to_device(dev);
expert_down_wt = slice_expert_range_3d(&down_3d, expert_start, local_experts, inter, hidden).to_device(dev);
expert_gate_up_wt =
slice_expert_range_3d(&gate_up_3d, expert_start, local_experts, hidden, inter2)
.to_device(dev);
expert_down_wt =
slice_expert_range_3d(&down_3d, expert_start, local_experts, inter, hidden)
.to_device(dev);
expert_gate_up_fp8 = None;
expert_gate_up_scale_gpu = None;
expert_down_fp8 = None;
expert_down_scale_gpu = None;
}
let expert_gate_up_bias = slice_expert_range_2d(&gate_up_bias_2d, expert_start, local_experts, inter2).to_device(dev);
let expert_down_bias = slice_expert_range_2d(&down_bias_2d, expert_start, local_experts, hidden).to_device(dev);
let expert_gate_up_bias =
slice_expert_range_2d(&gate_up_bias_2d, expert_start, local_experts, inter2)
.to_device(dev);
let expert_down_bias =
slice_expert_range_2d(&down_bias_2d, expert_start, local_experts, hidden)
.to_device(dev);
xserv_cuda::allocator::cached_trim();
let input_norm = repl(take(&mut w, &format!("{p}.input_layernorm.weight")));
let input_norm_bias = w.remove(&format!("{p}.input_layernorm.bias")).map(|t| repl(t));
let post_norm = repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight")));
let post_norm_bias = w.remove(&format!("{p}.post_attention_layernorm.bias")).map(|t| repl(t));
let input_norm_bias = w
.remove(&format!("{p}.input_layernorm.bias"))
.map(|t| repl(t));
let post_norm = repl(take(
&mut w,
&format!("{p}.post_attention_layernorm.weight"),
));
let post_norm_bias = w
.remove(&format!("{p}.post_attention_layernorm.bias"))
.map(|t| repl(t));
layers.push(GptOssBlock {
input_norm,
@@ -283,17 +364,27 @@ impl GptOss {
let local_num_kv_heads = config.num_kv_heads() / world;
let has_norm_bias = norm_bias.is_some();
let is_fp8 = layers.first().map(|l| l.expert_gate_up_fp8.is_some()).unwrap_or(false);
let is_mxfp4 = layers.first().map(|l| l.expert_gate_up_mxfp4.is_some()).unwrap_or(false);
let is_fp8 = layers
.first()
.map(|l| l.expert_gate_up_fp8.is_some())
.unwrap_or(false);
let is_mxfp4 = layers
.first()
.map(|l| l.expert_gate_up_mxfp4.is_some())
.unwrap_or(false);
if rank == 0 {
if has_norm_bias {
eprintln!("gpt-oss: detected LayerNorm bias — using LayerNorm instead of RMSNorm");
}
if is_fp8 {
eprintln!("gpt-oss: FP8 E4M3 quantized expert weights detected (W8A8 cuBLASLt mode)");
eprintln!(
"gpt-oss: FP8 E4M3 quantized expert weights detected (W8A8 cuBLASLt mode)"
);
}
if is_mxfp4 {
eprintln!("gpt-oss: MXFP4 quantized expert weights detected (W4A16 fused-GEMV mode)");
eprintln!(
"gpt-oss: MXFP4 quantized expert weights detected (W4A16 fused-GEMV mode)"
);
}
}
@@ -341,7 +432,13 @@ impl GptOss {
}
#[inline]
fn add_norm(x: &Tensor, residual: &Tensor, weight: &Tensor, bias: &Option<Tensor>, eps: f32) -> (Tensor, Tensor) {
fn add_norm(
x: &Tensor,
residual: &Tensor,
weight: &Tensor,
bias: &Option<Tensor>,
eps: f32,
) -> (Tensor, Tensor) {
match bias {
Some(b) => {
let sum = xserv_kernels::add(x, residual);
@@ -373,24 +470,62 @@ impl GptOss {
assert_eq!(seq_slots.len(), batch);
assert!(batch > 0);
let num_heads = self.local_num_heads;
let num_kv_heads = self.local_num_kv_heads;
let head_dim = self.config.head_dim();
let eps = self.norm_eps();
self.decode_prepare(positions, seq_slots, paged_cache);
// Upload token ids + positions, then run the pure-GPU core.
let ids_gpu = upload_u32(tokens);
let positions_u32: Vec<u32> = positions.iter().map(|&p| p as u32).collect();
let pos_gpu = upload_u32(&positions_u32);
let logits = self.decode_core(
ids_gpu.as_ptr() as *const c_void,
pos_gpu.as_ptr() as *const c_void,
batch,
paged_cache,
);
for &slot in seq_slots {
paged_cache.advance_seq_len(slot, 1);
}
logits
}
/// Host-side per-step cache bookkeeping: block allocation + uploading
/// block tables / context lens to their (stable-address) GPU buffers.
/// Runs OUTSIDE the CUDA-graph captured region.
pub fn decode_prepare(
&self,
positions: &[usize],
seq_slots: &[usize],
paged_cache: &mut PagedKVCache,
) {
let kv_lens: Vec<i32> = positions.iter().map(|&p| (p + 1) as i32).collect();
for (b, &slot) in seq_slots.iter().enumerate() {
paged_cache.ensure_capacity(slot, positions[b] + 1);
}
paged_cache.sync_active_batch_with_lens(seq_slots, &kv_lens);
}
/// The pure-GPU decode step: embedding → 24 layers → final norm → logits.
/// Token ids and positions are read from device buffers; every other input
/// (weights, KV pools, block table, context lens) has a stable address —
/// which is exactly what makes this region CUDA-graph capturable.
pub fn decode_core(
&self,
ids_gpu: *const c_void,
pos_gpu: *const c_void,
batch: usize,
paged_cache: &mut PagedKVCache,
) -> Tensor {
let num_heads = self.local_num_heads;
let num_kv_heads = self.local_num_kv_heads;
let head_dim = self.config.head_dim();
let eps = self.norm_eps();
let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
let max_blocks = paged_cache.max_blocks_per_seq();
let positions_u32: Vec<u32> = positions.iter().map(|&p| p as u32).collect();
let mut x = embedding(&self.embed_tokens, tokens);
let mut x = embedding_device_ids(&self.embed_tokens, ids_gpu, batch);
for (layer_idx, layer) in self.layers.iter().enumerate() {
let residual = x.clone();
@@ -401,14 +536,13 @@ impl GptOss {
let k_all = add_bias(&matmul_2d(&normed, &layer.k_proj_wt), &layer.k_proj_bias);
let v_all = add_bias(&matmul_2d(&normed, &layer.v_proj_wt), &layer.v_proj_bias);
// Reshape for RoPE: [B, H*D] → [B, H, D]
let q_3d = q_all.reshape(&[batch, num_heads, head_dim]);
let k_3d = k_all.reshape(&[batch, num_kv_heads, head_dim]);
// RoPE (no QK-norm for gpt-oss)
rope_inplace(&q_3d, &self.rope_cache, &positions_u32);
rope_inplace(&k_3d, &self.rope_cache, &positions_u32);
rope_inplace_device_pos(&q_3d, &self.rope_cache, pos_gpu);
rope_inplace_device_pos(&k_3d, &self.rope_cache, pos_gpu);
let v_3d = v_all.reshape(&[batch, num_kv_heads, head_dim]);
@@ -422,9 +556,17 @@ impl GptOss {
let sinks_ptr = layer.sinks.data_ptr() as *const c_void;
let attn_out = paged_decode_attention_sinks(
&q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr,
&q_4d,
k_pool_ptr,
v_pool_ptr,
bt_ptr,
cl_ptr,
sinks_ptr,
batch, num_heads, num_kv_heads, head_dim, max_blocks,
batch,
num_heads,
num_kv_heads,
head_dim,
max_blocks,
layer.window_size,
);
@@ -433,9 +575,14 @@ impl GptOss {
self.all_reduce(&attn_proj);
let attn_proj = add_bias(&attn_proj, &layer.o_proj_bias);
// Residual + post-norm
let (normed, x_new) = Self::add_norm(&attn_proj, &residual, &layer.post_norm, &layer.post_norm_bias, eps);
let (normed, x_new) = Self::add_norm(
&attn_proj,
&residual,
&layer.post_norm,
&layer.post_norm_bias,
eps,
);
let residual = x_new;
let normed = normed.contiguous();
@@ -445,17 +592,8 @@ impl GptOss {
x = xserv_kernels::add(&residual, &moe_out);
}
// Advance KV cache
for &slot in seq_slots {
paged_cache.advance_seq_len(slot, 1);
}
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
let x = Self::norm(&x, &self.norm, &self.norm_bias, eps);
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
let logits = matmul_2d(&x, &self.lm_head_t);
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
logits
matmul_2d(&x, &self.lm_head_t)
}
/// Paged prefill: process full prompt tokens.
@@ -476,7 +614,9 @@ impl GptOss {
paged_cache.advance_seq_len(slot, new_tokens);
let mut x = embedding(&self.embed_tokens, token_ids);
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
.map(|p| p as u32)
.collect();
for (layer_idx, layer) in self.layers.iter().enumerate() {
let residual = x.clone();
@@ -503,14 +643,21 @@ impl GptOss {
let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
// Flash attention with gpt-oss sinks + (per-layer) sliding window.
let attn_out = flash_attention_sinks(&q, &k_full, &v_full, &layer.sinks, layer.window_size);
let attn_out =
flash_attention_sinks(&q, &k_full, &v_full, &layer.sinks, layer.window_size);
let attn_merged = merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
self.all_reduce(&attn_proj);
let attn_proj = add_bias(&attn_proj, &layer.o_proj_bias);
let (normed, x_new) = Self::add_norm(&attn_proj, &residual, &layer.post_norm, &layer.post_norm_bias, eps);
let (normed, x_new) = Self::add_norm(
&attn_proj,
&residual,
&layer.post_norm,
&layer.post_norm_bias,
eps,
);
let residual = x_new;
// MoE MLP
@@ -519,9 +666,7 @@ impl GptOss {
}
let x = Self::norm(&x, &self.norm, &self.norm_bias, eps);
let logits = matmul_2d(&x, &self.lm_head_t);
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
logits
matmul_2d(&x, &self.lm_head_t)
}
/// MoE forward pass — fully on GPU via batched GEMM.
@@ -539,15 +684,101 @@ impl GptOss {
let expert_start = rank * local_experts;
// 1. Router: [tokens, hidden] @ [hidden, num_experts] + bias → [tokens, num_experts]
let router_logits = add_bias(
&matmul_2d(x, &layer.router_wt),
&layer.router_bias,
);
let router_logits = add_bias(&matmul_2d(x, &layer.router_wt), &layer.router_bias);
// 2. GPU top-k + softmax
let (topk_ids, topk_weights) = xserv_kernels::moe::moe_topk_softmax(
&router_logits, num_experts, top_k,
);
let (topk_ids, topk_weights) =
xserv_kernels::moe::moe_topk_softmax(&router_logits, num_experts, top_k);
// Sparse decode path: compute ONLY the routed experts. The dense path
// below reads every local expert's weights per forward; the sparse
// GEMVs read ~top_k/num_experts of the bytes, which dominates decode
// (memory-bound). Dense reads each weight once for ALL tokens, so it
// wins back at num_tokens ≈ local_experts / E[local hits] ≈ 8.
const SPARSE_MAX_TOKENS: usize = 8;
let quantized = layer.expert_gate_up_fp8.is_some() || layer.expert_gate_up_mxfp4.is_some();
if num_tokens <= SPARSE_MAX_TOKENS && quantized && !dense_moe_forced() {
let gate_up = if let Some((ref packed, ref scales)) = layer.expert_gate_up_mxfp4 {
let n = packed.shape()[1];
let k = packed.shape()[2] * 2;
xserv_kernels::moe::moe_sparse_gemv_mxfp4(
x,
packed,
scales,
&layer.expert_gate_up_bias,
&topk_ids,
num_tokens,
top_k,
n,
k,
expert_start,
local_experts,
false,
)
} else {
xserv_kernels::moe::moe_sparse_gemv_fp8(
x,
layer.expert_gate_up_fp8.as_ref().unwrap(),
layer.expert_gate_up_scale.as_ref().unwrap(),
&layer.expert_gate_up_bias,
&topk_ids,
num_tokens,
top_k,
expert_start,
local_experts,
false,
)
};
// GLU over all slots. Non-local slots hold unwritten memory; they
// are never consumed (the down GEMV and the weighted sum both skip
// slots whose expert this rank does not own).
let inter2 = gate_up.shape()[2];
let gate_up_flat = gate_up.reshape(&[num_tokens * top_k, inter2]);
let activated = gpt_oss_glu(&gate_up_flat, layer.glu_alpha, layer.glu_limit);
let down = if let Some((ref packed, ref scales)) = layer.expert_down_mxfp4 {
let n = packed.shape()[1];
let k = packed.shape()[2] * 2;
xserv_kernels::moe::moe_sparse_gemv_mxfp4(
&activated,
packed,
scales,
&layer.expert_down_bias,
&topk_ids,
num_tokens,
top_k,
n,
k,
expert_start,
local_experts,
true,
)
} else {
xserv_kernels::moe::moe_sparse_gemv_fp8(
&activated,
layer.expert_down_fp8.as_ref().unwrap(),
layer.expert_down_scale.as_ref().unwrap(),
&layer.expert_down_bias,
&topk_ids,
num_tokens,
top_k,
expert_start,
local_experts,
true,
)
};
let moe_out = xserv_kernels::moe::moe_weighted_sum_sparse(
&down,
&topk_ids,
&topk_weights,
expert_start,
local_experts,
);
self.all_reduce(&moe_out);
return moe_out;
}
// 3. Replicate input: [tokens, hidden] → [local_experts, tokens, hidden]
let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts);
@@ -563,14 +794,24 @@ impl GptOss {
xserv_kernels::quantization::batched_gemv_mxfp4(&x2, packed, scales, n, k)
.reshape(&[local_experts, 1, n])
} else {
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(packed, scales, local_experts, n, k);
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(
packed,
scales,
local_experts,
n,
k,
);
xserv_kernels::moe::batched_gemm_strided(&x_rep, &w_bf16)
}
} else if let Some(ref wt_fp8_t) = layer.expert_gate_up_fp8 {
// W8A8: quantize activations with per-expert scalar scale, use cuBLASLt FP8 GEMM
let (x_fp8, x_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&x_rep);
let (x_fp8, x_scales) =
xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&x_rep);
xserv_kernels::quantization::batched_gemm_fp8(
&x_fp8, &x_scales, wt_fp8_t, layer.expert_gate_up_scale.as_ref().unwrap(),
&x_fp8,
&x_scales,
wt_fp8_t,
layer.expert_gate_up_scale.as_ref().unwrap(),
)
} else {
xserv_kernels::moe::batched_gemm_strided(&x_rep, &layer.expert_gate_up_wt)
@@ -596,14 +837,24 @@ impl GptOss {
xserv_kernels::quantization::batched_gemv_mxfp4(&a2, packed, scales, n, k)
.reshape(&[local_experts, 1, n])
} else {
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(packed, scales, local_experts, n, k);
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(
packed,
scales,
local_experts,
n,
k,
);
xserv_kernels::moe::batched_gemm_strided(&activated, &w_bf16)
}
} else if let Some(ref wt_fp8) = layer.expert_down_fp8 {
// W8A8: quantize post-GLU activations to FP8, use cuBLASLt FP8 GEMM
let (act_fp8, act_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&activated);
let (act_fp8, act_scales) =
xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&activated);
xserv_kernels::quantization::batched_gemm_fp8(
&act_fp8, &act_scales, wt_fp8, layer.expert_down_scale.as_ref().unwrap(),
&act_fp8,
&act_scales,
wt_fp8,
layer.expert_down_scale.as_ref().unwrap(),
)
} else {
xserv_kernels::moe::batched_gemm_strided(&activated, &layer.expert_down_wt)
@@ -614,8 +865,12 @@ impl GptOss {
// 9. Weighted sum across experts → [tokens, hidden]
let moe_out = xserv_kernels::moe::moe_weighted_sum(
&down, &topk_ids, &topk_weights,
expert_start, local_experts, top_k,
&down,
&topk_ids,
&topk_weights,
expert_start,
local_experts,
top_k,
);
self.all_reduce(&moe_out);
@@ -625,45 +880,45 @@ impl GptOss {
// --- Helpers ---
/// Upload a u32 slice to a pooled GPU buffer (synchronous H2D).
fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer {
let bytes = unsafe { std::slice::from_raw_parts(vals.as_ptr() as *const u8, vals.len() * 4) };
let mut buf = xserv_cuda::allocator::cached_alloc(bytes.len()).expect("alloc u32 upload");
buf.copy_from_host(bytes).unwrap();
buf
}
/// XSERV_DENSE_MOE=1 forces the dense all-expert path (A/B benchmarking).
fn dense_moe_forced() -> bool {
static FORCED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
*FORCED.get_or_init(|| std::env::var("XSERV_DENSE_MOE").is_ok_and(|v| v != "0"))
}
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
assert_eq!(a.ndim(), 2);
assert_eq!(b.ndim(), 2);
matmul(a, b, GemmBackend::CuBlas)
}
/// Add bias to a 2D tensor: [rows, cols] + [cols] → [rows, cols]
/// Add bias to a 2D tensor: [rows, cols] + [cols] → [rows, cols].
/// Single GPU broadcast kernel — the old rows>1 path tiled the bias on the
/// CPU (D2H + host loop + H2D) on every call, 96×/prefill in the hot path.
fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor {
assert_eq!(x.ndim(), 2);
assert_eq!(bias.ndim(), 1);
let rows = x.shape()[0];
let cols = x.shape()[1];
assert_eq!(bias.shape()[0], cols, "bias size {} != cols {}", bias.shape()[0], cols);
let x_c = x.contiguous();
if rows == 1 {
// Fast path: reshape bias [cols] → [1, cols] (zero-copy), add directly on GPU
let bias_2d = bias.reshape(&[1, cols]);
return xserv_kernels::add(&x_c, &bias_2d);
}
// General path: tile bias to [rows, cols] via CPU, then add on GPU
let bias_cpu = bias.to_device(Device::Cpu);
let bias_data = bias_cpu.as_slice::<bf16>();
let mut tiled = Vec::with_capacity(rows * cols);
for _ in 0..rows {
tiled.extend_from_slice(bias_data);
}
let bias_tiled = Tensor::from_slice(&tiled, &[rows, cols]).to_device(x.device());
xserv_kernels::add(&x_c, &bias_tiled)
xserv_kernels::bias_add_2d(&x_c, bias)
}
fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor {
if world == 1 { return t.clone(); }
if world == 1 {
return t.clone();
}
let shape = t.shape();
assert_eq!(shape.len(), 2);
let (rows, cols) = (shape[0], shape[1]);
assert!(rows % world == 0, "rows {rows} not divisible by world {world}");
assert!(
rows % world == 0,
"rows {rows} not divisible by world {world}"
);
let local = rows / world;
let host = t.to_device(Device::Cpu);
let data = host.as_slice::<bf16>();
@@ -673,11 +928,16 @@ fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor {
}
fn shard_cols(t: &Tensor, rank: usize, world: usize) -> Tensor {
if world == 1 { return t.clone(); }
if world == 1 {
return t.clone();
}
let shape = t.shape();
assert_eq!(shape.len(), 2);
let (rows, cols) = (shape[0], shape[1]);
assert!(cols % world == 0, "cols {cols} not divisible by world {world}");
assert!(
cols % world == 0,
"cols {cols} not divisible by world {world}"
);
let local = cols / world;
let c0 = rank * local;
let host = t.to_device(Device::Cpu);
@@ -691,11 +951,16 @@ fn shard_cols(t: &Tensor, rank: usize, world: usize) -> Tensor {
}
fn shard_1d(t: &Tensor, rank: usize, world: usize) -> Tensor {
if world == 1 { return t.clone(); }
if world == 1 {
return t.clone();
}
let shape = t.shape();
assert_eq!(shape.len(), 1);
let total = shape[0];
assert!(total % world == 0, "dim {total} not divisible by world {world}");
assert!(
total % world == 0,
"dim {total} not divisible by world {world}"
);
let local = total / world;
let host = t.to_device(Device::Cpu);
let data = host.as_slice::<bf16>();
@@ -726,7 +991,13 @@ fn transpose_3d_inner_raw(t: &Tensor, batch: usize, rows: usize, cols: usize) ->
}
/// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor (any dtype, raw bytes).
fn slice_expert_range_3d_raw(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor {
fn slice_expert_range_3d_raw(
t: &Tensor,
start: usize,
count: usize,
rows: usize,
cols: usize,
) -> Tensor {
assert_eq!(t.ndim(), 3);
let host = t.to_device(Device::Cpu);
let elem_size = t.dtype().size_bytes();
@@ -748,7 +1019,13 @@ fn slice_scale_range(t: &Tensor, start: usize, count: usize) -> Tensor {
}
/// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor
fn slice_expert_range_3d(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor {
fn slice_expert_range_3d(
t: &Tensor,
start: usize,
count: usize,
rows: usize,
cols: usize,
) -> Tensor {
assert_eq!(t.ndim(), 3);
let host = t.to_device(Device::Cpu);
let data = host.as_slice::<bf16>();

View File

@@ -0,0 +1,195 @@
//! CUDA-graph replay for gpt-oss batch=1 decode (Phase 21).
//!
//! A decode step launches ~200 kernels; with sparse MoE the GPU work is only
//! a few ms, so launch overhead dominates TPOT. The whole step (embedding →
//! 24 layers → logits) is captured ONCE into a CUDA graph and replayed per
//! token with a single `cudaGraphLaunch`.
//!
//! Why the existing forward is capturable as-is:
//! - Every per-step variable input lives in a stable-address device buffer
//! whose CONTENTS are updated outside the captured region: token id and
//! position (persistent buffers owned here), block table and context lens
//! (PagedKVCache GPU buffers, refreshed by `decode_prepare`). The KV scatter
//! and paged attention kernels read their write/read positions from those
//! buffers, and the sparse-MoE GEMVs read expert ids from `topk_ids` written
//! earlier in the same graph — all data-dependent, no host branching.
//! - Kernel launches go through the thread-local launch stream
//! (`xserv_cuda::stream::push_stream`), so the capture stream sees them.
//! - Intermediate tensors come from the caching allocator. Blocks freed while
//! capturing are quarantined (`allocator::begin_retain`) for the graph's
//! lifetime so no later allocation can take ownership of memory the graph
//! still references on every replay.
//!
//! Capture preconditions: at least one EAGER decode step must have run first,
//! so the allocator pool already holds every bucket size the step needs
//! (a pool-miss inside capture would call cudaMalloc — illegal while
//! capturing) and cuBLAS has finished its one-time per-shape setup.
use std::ffi::c_void;
use xserv_cuda::allocator::{self, RetainedBlocks};
use xserv_cuda::{CudaGraph, CudaStream, GpuBuffer};
use xserv_tensor::Tensor;
use crate::gpt_oss::GptOss;
use crate::paged_kv_cache::PagedKVCache;
pub struct GptOssDecodeGraph {
stream: CudaStream,
graph: CudaGraph,
ids_buf: GpuBuffer, // [1] u32, persistent graph input
pos_buf: GpuBuffer, // [1] u32, persistent graph input
logits: Tensor, // graph output; rewritten in place by every replay
_arena: RetainedBlocks,
}
impl GptOssDecodeGraph {
/// Capture one batch=1 decode step and replay it once (capture records
/// without executing, so the replay performs this token's computation).
pub fn capture(
model: &GptOss,
token: u32,
position: usize,
slot: usize,
cache: &mut PagedKVCache,
) -> Self {
let stream = CudaStream::new().expect("create capture stream");
let mut ids_buf = allocator::cached_alloc(4).expect("alloc ids buf");
let mut pos_buf = allocator::cached_alloc(4).expect("alloc pos buf");
model.decode_prepare(&[position], &[slot], cache);
ids_buf.copy_from_host(&token.to_le_bytes()).unwrap();
pos_buf
.copy_from_host(&(position as u32).to_le_bytes())
.unwrap();
// Retained warmup: run the exact step once eagerly with the quarantine
// ON. Freed intermediates are held back instead of recycled, so the
// pool ends up stocked with a dedicated block for EVERY allocation the
// step performs. The capture below repeats the same allocation
// sequence and therefore never misses the pool — a pool miss would
// call cudaMalloc, which is illegal while a stream is capturing (this
// is also why one block per bucket is not enough: the capture's own
// quarantine keeps freed blocks out of reuse). Re-running the step is
// idempotent: the KV scatter rewrites the same cache position.
allocator::begin_retain();
{
let _guard = xserv_cuda::push_stream(&stream);
let _ = model.decode_core(
ids_buf.as_ptr() as *const c_void,
pos_buf.as_ptr() as *const c_void,
1,
cache,
);
}
drop(allocator::end_retain()); // release the warmup blocks to the pool
stream.synchronize().expect("warmup sync");
allocator::begin_retain();
let mut graph = CudaGraph::new();
let logits;
{
let _guard = xserv_cuda::stream::push_stream(&stream);
graph
.begin_capture(&stream)
.expect("begin decode-graph capture");
logits = model.decode_core(
ids_buf.as_ptr() as *const c_void,
pos_buf.as_ptr() as *const c_void,
1,
cache,
);
graph
.end_capture(&stream)
.expect("end decode-graph capture");
}
let arena = allocator::end_retain();
graph.launch(&stream).expect("first decode-graph replay");
cache.advance_seq_len(slot, 1);
Self {
stream,
graph,
ids_buf,
pos_buf,
logits,
_arena: arena,
}
}
/// Run one decode step by replaying the captured graph.
pub fn step(
&mut self,
model: &GptOss,
token: u32,
position: usize,
slot: usize,
cache: &mut PagedKVCache,
) -> Tensor {
model.decode_prepare(&[position], &[slot], cache);
self.ids_buf.copy_from_host(&token.to_le_bytes()).unwrap();
self.pos_buf
.copy_from_host(&(position as u32).to_le_bytes())
.unwrap();
self.graph
.launch(&self.stream)
.expect("decode-graph replay");
cache.advance_seq_len(slot, 1);
// Shallow clone: the caller reads these logits before the next replay
// rewrites the underlying buffer.
self.logits.clone()
}
}
/// Lazy capture policy: first decode step of the process runs eager (warms the
/// allocator pool + cuBLAS so capture performs no "unsafe" CUDA calls), the
/// second is captured, the rest replay. Batch>1 always falls back to eager.
/// Disable with XSERV_DECODE_GRAPH=0.
pub struct GraphedGptOssDecoder {
graph: Option<GptOssDecodeGraph>,
eager_steps: u32,
enabled: bool,
}
impl GraphedGptOssDecoder {
pub fn new() -> Self {
let enabled = std::env::var("XSERV_DECODE_GRAPH")
.map(|v| v != "0")
.unwrap_or(true);
Self {
graph: None,
eager_steps: 0,
enabled,
}
}
pub fn decode(
&mut self,
model: &GptOss,
tokens: &[u32],
positions: &[usize],
slots: &[usize],
cache: &mut PagedKVCache,
) -> Tensor {
if self.enabled && tokens.len() == 1 {
if let Some(g) = self.graph.as_mut() {
return g.step(model, tokens[0], positions[0], slots[0], cache);
}
if self.eager_steps >= 1 {
let g = GptOssDecodeGraph::capture(model, tokens[0], positions[0], slots[0], cache);
let logits = g.logits.clone();
self.graph = Some(g);
return logits;
}
}
self.eager_steps += 1;
model.forward_decode_paged(tokens, positions, slots, cache)
}
}
impl Default for GraphedGptOssDecoder {
fn default() -> Self {
Self::new()
}
}

View File

@@ -1,6 +1,6 @@
use xserv_cuda::GpuBuffer;
use xserv_tensor::{DType, Device, Tensor};
use crate::config::ModelConfig;
use xserv_cuda::GpuBuffer;
use xserv_tensor::{DType, Tensor};
/// GPU-resident KV cache. Pre-allocates max_seq_len on GPU,
/// appends new K/V via D2D copy at offset (no CPU round-trip).
@@ -46,17 +46,43 @@ impl GpuKVCache {
v_staging.push(GpuBuffer::alloc(buf_size).expect("alloc KV staging V"));
}
Self { k_bufs, v_bufs, k_staging, v_staging, seq_len: 0, max_seq_len, num_kv_heads, head_dim, elem_size, dtype, device }
Self {
k_bufs,
v_bufs,
k_staging,
v_staging,
seq_len: 0,
max_seq_len,
num_kv_heads,
head_dim,
elem_size,
dtype,
device,
}
}
pub fn seq_len(&self) -> usize { self.seq_len }
pub fn max_seq_len(&self) -> usize { self.max_seq_len }
pub fn seq_len(&self) -> usize {
self.seq_len
}
pub fn max_seq_len(&self) -> usize {
self.max_seq_len
}
/// Append new K/V tensors for a given layer.
/// k_new, v_new: [1, num_kv_heads, new_tokens, head_dim] on GPU, contiguous.
/// `write_pos` is the sequence position to write at (caller manages this).
pub fn append(&mut self, layer: usize, k_new: &Tensor, v_new: &Tensor, new_tokens: usize, write_pos: usize) {
assert!(write_pos + new_tokens <= self.max_seq_len, "KV cache overflow");
pub fn append(
&mut self,
layer: usize,
k_new: &Tensor,
v_new: &Tensor,
new_tokens: usize,
write_pos: usize,
) {
assert!(
write_pos + new_tokens <= self.max_seq_len,
"KV cache overflow"
);
let es = self.elem_size;
let hd = self.head_dim;
let max_s = self.max_seq_len;
@@ -69,14 +95,23 @@ impl GpuKVCache {
let src_off = h * new_tokens * hd * es;
let dst_off = (h * max_s + write_pos) * hd * es;
let count = new_tokens * hd * es;
self.k_bufs[layer].copy_from_device_at(k_src, src_off, dst_off, count).unwrap();
self.v_bufs[layer].copy_from_device_at(v_src, src_off, dst_off, count).unwrap();
self.k_bufs[layer]
.copy_from_device_at(k_src, src_off, dst_off, count)
.unwrap();
self.v_bufs[layer]
.copy_from_device_at(v_src, src_off, dst_off, count)
.unwrap();
}
}
pub fn advance_seq_len(&mut self, new_tokens: usize) {
self.seq_len += new_tokens;
assert!(self.seq_len <= self.max_seq_len, "KV cache seq_len ({}) exceeds max_seq_len ({})", self.seq_len, self.max_seq_len);
assert!(
self.seq_len <= self.max_seq_len,
"KV cache seq_len ({}) exceeds max_seq_len ({})",
self.seq_len,
self.max_seq_len
);
}
/// Get K/V cache tensors for a layer up to `seq_len` tokens: [1, num_kv_heads, seq_len, head_dim]
@@ -86,7 +121,11 @@ impl GpuKVCache {
}
pub fn get_kv_len(&mut self, layer: usize, sl: usize) -> (Tensor, Tensor) {
assert!(sl <= self.max_seq_len, "get_kv_len: sl ({sl}) exceeds max_seq_len ({})", self.max_seq_len);
assert!(
sl <= self.max_seq_len,
"get_kv_len: sl ({sl}) exceeds max_seq_len ({})",
self.max_seq_len
);
let hd = self.head_dim;
let nh = self.num_kv_heads;
let es = self.elem_size;
@@ -104,8 +143,12 @@ impl GpuKVCache {
let src_off = (h * max_s) * hd * es;
let dst_off = (h * sl) * hd * es;
let count = sl * hd * es;
k_stg.copy_from_device_at(k_buf, src_off, dst_off, count).unwrap();
v_stg.copy_from_device_at(v_buf, src_off, dst_off, count).unwrap();
k_stg
.copy_from_device_at(k_buf, src_off, dst_off, count)
.unwrap();
v_stg
.copy_from_device_at(v_buf, src_off, dst_off, count)
.unwrap();
}
// Grab raw pointers before dropping the mutable borrows
let k_ptr = k_stg.as_mut_ptr();
@@ -117,20 +160,35 @@ impl GpuKVCache {
// get_kv_len call overwrites the staging buffer).
let shape = &[1usize, nh, sl, hd];
let k = unsafe {
tensor_from_gpu_buffer(GpuBuffer::borrow_raw(k_ptr, out_size), shape, self.dtype, self.device)
tensor_from_gpu_buffer(
GpuBuffer::borrow_raw(k_ptr, out_size),
shape,
self.dtype,
self.device,
)
};
let v = unsafe {
tensor_from_gpu_buffer(GpuBuffer::borrow_raw(v_ptr, out_size), shape, self.dtype, self.device)
tensor_from_gpu_buffer(
GpuBuffer::borrow_raw(v_ptr, out_size),
shape,
self.dtype,
self.device,
)
};
(k, v)
}
}
/// Create a Tensor from a GpuBuffer (takes ownership).
unsafe fn tensor_from_gpu_buffer(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor {
use xserv_tensor::storage::Storage;
use xserv_tensor::shape::contiguous_strides;
unsafe fn tensor_from_gpu_buffer(
buf: GpuBuffer,
shape: &[usize],
dtype: DType,
device: u32,
) -> Tensor {
use smallvec::SmallVec;
use xserv_tensor::shape::contiguous_strides;
use xserv_tensor::storage::Storage;
let storage = Storage::cuda(buf, device);
Tensor::from_storage(
@@ -146,6 +204,11 @@ unsafe fn tensor_from_gpu_buffer(buf: GpuBuffer, shape: &[usize], dtype: DType,
///
/// # Safety
/// `buf` must be a valid GPU allocation with at least `product(shape) * dtype.size_bytes()` bytes.
pub unsafe fn tensor_from_gpu_buffer_pub(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor {
pub unsafe fn tensor_from_gpu_buffer_pub(
buf: GpuBuffer,
shape: &[usize],
dtype: DType,
device: u32,
) -> Tensor {
tensor_from_gpu_buffer(buf, shape, dtype, device)
}

View File

@@ -2,6 +2,7 @@ pub mod config;
pub mod decode_graph;
pub mod gpt2;
pub mod gpt_oss;
pub mod gpt_oss_graph;
pub mod kv_cache;
pub mod loader;
pub mod paged_kv_cache;
@@ -10,10 +11,11 @@ pub mod sampling;
pub use config::ModelConfig;
pub use decode_graph::{DecodeGraphState, LayerWeightPtrs};
pub use gpt2::{GPT2, KVCache};
pub use gpt_oss::GptOss;
pub use gpt_oss_graph::{GptOssDecodeGraph, GraphedGptOssDecoder};
pub use gpt2::{GPT2, KVCache};
pub use kv_cache::GpuKVCache;
pub use paged_kv_cache::{BlockAllocator, Location, PagedKVCache, BLOCK_SIZE};
pub use paged_kv_cache::{BLOCK_SIZE, BlockAllocator, Location, PagedKVCache};
pub use qwen3::Qwen3;
pub use sampling::{SamplingParams, sample, sample_greedy_penalized};

View File

@@ -5,8 +5,8 @@ use std::path::Path;
use xserv_tensor::{DType, Device, Tensor};
pub fn load_safetensors(path: &Path, device: Device) -> HashMap<String, Tensor> {
let data = std::fs::read(path)
.unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
let data =
std::fs::read(path).unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
let st = SafeTensors::deserialize(&data)
.unwrap_or_else(|e| panic!("failed to parse safetensors {}: {e}", path.display()));
@@ -60,7 +60,11 @@ pub fn load_model_dir(dir: &Path, device: Device) -> HashMap<String, Tensor> {
all_tensors.extend(tensors);
}
assert!(!all_tensors.is_empty(), "no safetensors files found in {}", dir.display());
assert!(
!all_tensors.is_empty(),
"no safetensors files found in {}",
dir.display()
);
all_tensors
}
@@ -84,8 +88,6 @@ fn make_tensor(raw_bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor {
};
Tensor::from_slice(bfs, shape)
}
DType::FP8E4M3 => {
Tensor::from_raw_bytes(raw_bytes, shape, DType::FP8E4M3)
}
DType::FP8E4M3 => Tensor::from_raw_bytes(raw_bytes, shape, DType::FP8E4M3),
}
}

View File

@@ -29,7 +29,10 @@ impl BlockAllocator {
for b in (1..total_blocks).rev() {
free_stack.push(b as u32);
}
Self { free_stack, total: total_blocks }
Self {
free_stack,
total: total_blocks,
}
}
pub fn alloc(&mut self) -> Option<u32> {
@@ -136,8 +139,14 @@ impl PagedKVCache {
device: u32,
) -> Self {
Self::new_tp(
config, config.num_kv_heads(), total_blocks, cpu_total_blocks,
max_seqs, max_blocks_per_seq, dtype, device,
config,
config.num_kv_heads(),
total_blocks,
cpu_total_blocks,
max_seqs,
max_blocks_per_seq,
dtype,
device,
)
}
@@ -155,7 +164,10 @@ impl PagedKVCache {
dtype: DType,
device: u32,
) -> Self {
assert!(total_blocks >= 2, "need at least 2 blocks (one is sentinel)");
assert!(
total_blocks >= 2,
"need at least 2 blocks (one is sentinel)"
);
let num_layers = config.num_layers();
let head_dim = config.head_dim();
let elem_size = dtype.size_bytes();
@@ -179,11 +191,17 @@ impl PagedKVCache {
if cpu_total_blocks >= 2 {
let cpu_pool_bytes = cpu_total_blocks * block_bytes;
for _ in 0..num_layers {
cpu_k_pools.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU K swap pool"));
cpu_v_pools.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU V swap pool"));
cpu_k_pools
.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU K swap pool"));
cpu_v_pools
.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU V swap pool"));
}
}
let cpu_allocator = BlockAllocator::new(if cpu_total_blocks >= 2 { cpu_total_blocks } else { 0 });
let cpu_allocator = BlockAllocator::new(if cpu_total_blocks >= 2 {
cpu_total_blocks
} else {
0
});
let block_table_gpu =
GpuBuffer::alloc(max_seqs * max_blocks_per_seq * std::mem::size_of::<i32>())
@@ -220,22 +238,49 @@ impl PagedKVCache {
}
}
pub fn num_layers(&self) -> usize { self.num_layers }
pub fn num_kv_heads(&self) -> usize { self.num_kv_heads }
pub fn head_dim(&self) -> usize { self.head_dim }
pub fn dtype(&self) -> DType { self.dtype }
pub fn max_seqs(&self) -> usize { self.max_seqs }
pub fn max_blocks_per_seq(&self) -> usize { self.max_blocks_per_seq }
pub fn free_blocks(&self) -> usize { self.allocator.free_count() }
pub fn total_blocks(&self) -> usize { self.allocator.total() }
pub fn num_layers(&self) -> usize {
self.num_layers
}
pub fn num_kv_heads(&self) -> usize {
self.num_kv_heads
}
pub fn head_dim(&self) -> usize {
self.head_dim
}
pub fn dtype(&self) -> DType {
self.dtype
}
pub fn max_seqs(&self) -> usize {
self.max_seqs
}
pub fn max_blocks_per_seq(&self) -> usize {
self.max_blocks_per_seq
}
pub fn free_blocks(&self) -> usize {
self.allocator.free_count()
}
pub fn total_blocks(&self) -> usize {
self.allocator.total()
}
pub fn k_pool(&self, layer: usize) -> &GpuBuffer { &self.k_pools[layer] }
pub fn v_pool(&self, layer: usize) -> &GpuBuffer { &self.v_pools[layer] }
pub fn block_table_gpu(&self) -> &GpuBuffer { &self.block_table_gpu }
pub fn context_lens_gpu(&self) -> &GpuBuffer { &self.context_lens_gpu }
pub fn k_pool(&self, layer: usize) -> &GpuBuffer {
&self.k_pools[layer]
}
pub fn v_pool(&self, layer: usize) -> &GpuBuffer {
&self.v_pools[layer]
}
pub fn block_table_gpu(&self) -> &GpuBuffer {
&self.block_table_gpu
}
pub fn context_lens_gpu(&self) -> &GpuBuffer {
&self.context_lens_gpu
}
pub fn seq_len(&self, slot: usize) -> usize {
self.seq_states[slot].as_ref().map(|s| s.seq_len).unwrap_or(0)
self.seq_states[slot]
.as_ref()
.map(|s| s.seq_len)
.unwrap_or(0)
}
pub fn is_slot_free(&self, slot: usize) -> bool {
@@ -280,7 +325,11 @@ impl PagedKVCache {
let state = self.seq_states[slot].as_ref().expect("unregistered slot");
let cur = state.block_ids.len();
let needed_total = (state.seq_len + new_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE;
if needed_total > cur { needed_total - cur } else { 0 }
if needed_total > cur {
needed_total - cur
} else {
0
}
}
/// Pre-allocate enough physical blocks in `slot` to cover positions
@@ -290,8 +339,14 @@ impl PagedKVCache {
let state = self.seq_states[slot].as_mut().expect("unregistered slot");
let needed_total = (end_pos + BLOCK_SIZE - 1) / BLOCK_SIZE;
while state.block_ids.len() < needed_total {
let b = self.allocator.alloc().expect("out of blocks (caller must check)");
assert!(state.block_ids.len() < self.max_blocks_per_seq, "block table overflow");
let b = self
.allocator
.alloc()
.expect("out of blocks (caller must check)");
assert!(
state.block_ids.len() < self.max_blocks_per_seq,
"block table overflow"
);
state.block_ids.push(b);
}
}
@@ -318,7 +373,9 @@ impl PagedKVCache {
num_tokens: usize,
start_pos: usize,
) {
if num_tokens == 0 { return; }
if num_tokens == 0 {
return;
}
// Make sure blocks exist for the target range.
self.ensure_capacity(slot, start_pos + num_tokens);
@@ -328,15 +385,21 @@ impl PagedKVCache {
// Stage block_ids on the GPU. Pool-allocated so this is essentially
// free after the first call (same bucket every step).
let block_ids: Vec<i32> = self.seq_states[slot].as_ref().unwrap()
.block_ids.iter().map(|&b| b as i32).collect();
let block_ids: Vec<i32> = self.seq_states[slot]
.as_ref()
.unwrap()
.block_ids
.iter()
.map(|&b| b as i32)
.collect();
let bytes = block_ids.len() * std::mem::size_of::<i32>();
let mut block_ids_gpu = xserv_cuda::allocator::cached_alloc(bytes)
.expect("alloc append block_ids");
let block_ids_bytes = unsafe {
std::slice::from_raw_parts(block_ids.as_ptr() as *const u8, bytes)
};
block_ids_gpu.copy_from_host(block_ids_bytes).expect("upload block_ids");
let mut block_ids_gpu =
xserv_cuda::allocator::cached_alloc(bytes).expect("alloc append block_ids");
let block_ids_bytes =
unsafe { std::slice::from_raw_parts(block_ids.as_ptr() as *const u8, bytes) };
block_ids_gpu
.copy_from_host(block_ids_bytes)
.expect("upload block_ids");
let k_src = k_new.data_ptr() as *const std::ffi::c_void;
let v_src = v_new.data_ptr() as *const std::ffi::c_void;
@@ -345,11 +408,17 @@ impl PagedKVCache {
unsafe {
xserv_kernels::reshape_and_cache_bf16(
k_src, v_src,
k_pool_ptr, v_pool_ptr,
k_src,
v_src,
k_pool_ptr,
v_pool_ptr,
block_ids_gpu.as_ptr() as *const i32,
num_tokens, nkv, hd, start_pos, bs,
std::ptr::null_mut(),
num_tokens,
nkv,
hd,
start_pos,
bs,
xserv_cuda::current_stream_raw(),
);
}
// block_ids_gpu drops here; the launch on the null stream will have
@@ -378,7 +447,9 @@ impl PagedKVCache {
v_new: &Tensor,
batch: usize,
) {
if batch == 0 { return; }
if batch == 0 {
return;
}
let nkv = self.num_kv_heads;
let hd = self.head_dim;
debug_assert_eq!(k_new.shape(), &[batch, nkv, hd]);
@@ -393,11 +464,18 @@ impl PagedKVCache {
unsafe {
xserv_kernels::reshape_and_cache_batched_bf16(
k_src, v_src,
k_pool_ptr, v_pool_ptr,
bt_ptr, cl_ptr,
batch, nkv, hd, BLOCK_SIZE, self.max_blocks_per_seq,
std::ptr::null_mut(),
k_src,
v_src,
k_pool_ptr,
v_pool_ptr,
bt_ptr,
cl_ptr,
batch,
nkv,
hd,
BLOCK_SIZE,
self.max_blocks_per_seq,
xserv_cuda::current_stream_raw(),
);
}
}
@@ -447,7 +525,10 @@ impl PagedKVCache {
/// before advance_seq_len has run).
pub fn sync_active_batch_with_lens(&mut self, slots: &[usize], kv_lens: &[i32]) {
assert_eq!(slots.len(), kv_lens.len());
assert!(slots.len() <= self.max_seqs, "active batch exceeds max_seqs");
assert!(
slots.len() <= self.max_seqs,
"active batch exceeds max_seqs"
);
let stride = self.max_blocks_per_seq;
for row in &mut self.block_table_host {
*row = 0;
@@ -456,7 +537,9 @@ impl PagedKVCache {
*cl = 0;
}
for (i, &slot) in slots.iter().enumerate() {
let s = self.seq_states[slot].as_ref().expect("unregistered slot in active batch");
let s = self.seq_states[slot]
.as_ref()
.expect("unregistered slot in active batch");
let row = &mut self.block_table_host[i * stride..(i + 1) * stride];
for (j, b) in s.block_ids.iter().enumerate() {
row[j] = *b as i32;
@@ -515,8 +598,12 @@ impl PagedKVCache {
let src_off = ((phys * nkv + h) * bs + slot_in_blk) * hd * es;
let dst_off = (h * sl + p) * hd * es;
let count = chunk * hd * es;
k_dst.copy_from_device_at(k_pool, src_off, dst_off, count).unwrap();
v_dst.copy_from_device_at(v_pool, src_off, dst_off, count).unwrap();
k_dst
.copy_from_device_at(k_pool, src_off, dst_off, count)
.unwrap();
v_dst
.copy_from_device_at(v_pool, src_off, dst_off, count)
.unwrap();
}
p += chunk;
}
@@ -529,16 +616,26 @@ impl PagedKVCache {
// ----- Swapping (vLLM-style preemption to pinned host memory) -----
pub fn free_cpu_blocks(&self) -> usize { self.cpu_allocator.free_count() }
pub fn swap_enabled(&self) -> bool { !self.cpu_k_pools.is_empty() }
pub fn free_cpu_blocks(&self) -> usize {
self.cpu_allocator.free_count()
}
pub fn swap_enabled(&self) -> bool {
!self.cpu_k_pools.is_empty()
}
pub fn is_swapped(&self, slot: usize) -> bool {
matches!(self.seq_states[slot].as_ref().map(|s| s.location), Some(Location::Cpu))
matches!(
self.seq_states[slot].as_ref().map(|s| s.location),
Some(Location::Cpu)
)
}
/// Number of physical blocks currently held by `slot` (in either pool).
pub fn block_count(&self, slot: usize) -> usize {
self.seq_states[slot].as_ref().map(|s| s.block_ids.len()).unwrap_or(0)
self.seq_states[slot]
.as_ref()
.map(|s| s.block_ids.len())
.unwrap_or(0)
}
/// Whether a swapped sequence at `slot` can be brought back (enough free GPU blocks).
@@ -554,11 +651,17 @@ impl PagedKVCache {
/// Evict `slot`'s KV from GPU to pinned host memory and free its GPU blocks.
/// The slot stays registered (location = Cpu); the sequence is paused.
pub fn swap_out(&mut self, slot: usize) -> Result<(), &'static str> {
let state = self.seq_states[slot].as_ref().ok_or("swap_out: empty slot")?;
if state.location == Location::Cpu { return Ok(()); }
let state = self.seq_states[slot]
.as_ref()
.ok_or("swap_out: empty slot")?;
if state.location == Location::Cpu {
return Ok(());
}
let gpu_ids = state.block_ids.clone();
let n = gpu_ids.len();
if !self.cpu_allocator.can_alloc(n) { return Err("swap_out: CPU pool full"); }
if !self.cpu_allocator.can_alloc(n) {
return Err("swap_out: CPU pool full");
}
let cpu_ids: Vec<u32> = (0..n)
.map(|_| self.cpu_allocator.alloc().expect("checked can_alloc"))
@@ -570,10 +673,18 @@ impl PagedKVCache {
let g_off = gpu_ids[i] as usize * bb;
let c_off = cpu_ids[i] as usize * bb;
self.k_pools[layer]
.copy_to_host_at(&mut self.cpu_k_pools[layer].as_mut_slice()[c_off..c_off + bb], g_off, bb)
.copy_to_host_at(
&mut self.cpu_k_pools[layer].as_mut_slice()[c_off..c_off + bb],
g_off,
bb,
)
.unwrap();
self.v_pools[layer]
.copy_to_host_at(&mut self.cpu_v_pools[layer].as_mut_slice()[c_off..c_off + bb], g_off, bb)
.copy_to_host_at(
&mut self.cpu_v_pools[layer].as_mut_slice()[c_off..c_off + bb],
g_off,
bb,
)
.unwrap();
}
}
@@ -589,11 +700,17 @@ impl PagedKVCache {
/// Bring `slot`'s KV back from host to GPU and free its CPU blocks.
pub fn swap_in(&mut self, slot: usize) -> Result<(), &'static str> {
let state = self.seq_states[slot].as_ref().ok_or("swap_in: empty slot")?;
if state.location == Location::Gpu { return Ok(()); }
let state = self.seq_states[slot]
.as_ref()
.ok_or("swap_in: empty slot")?;
if state.location == Location::Gpu {
return Ok(());
}
let cpu_ids = state.block_ids.clone();
let n = cpu_ids.len();
if !self.allocator.can_alloc(n) { return Err("swap_in: GPU pool full"); }
if !self.allocator.can_alloc(n) {
return Err("swap_in: GPU pool full");
}
let gpu_ids: Vec<u32> = (0..n)
.map(|_| self.allocator.alloc().expect("checked can_alloc"))
@@ -605,10 +722,18 @@ impl PagedKVCache {
let g_off = gpu_ids[i] as usize * bb;
let c_off = cpu_ids[i] as usize * bb;
self.k_pools[layer]
.copy_from_host_at(&self.cpu_k_pools[layer].as_slice()[c_off..c_off + bb], g_off, bb)
.copy_from_host_at(
&self.cpu_k_pools[layer].as_slice()[c_off..c_off + bb],
g_off,
bb,
)
.unwrap();
self.v_pools[layer]
.copy_from_host_at(&self.cpu_v_pools[layer].as_slice()[c_off..c_off + bb], g_off, bb)
.copy_from_host_at(
&self.cpu_v_pools[layer].as_slice()[c_off..c_off + bb],
g_off,
bb,
)
.unwrap();
}
}
@@ -623,7 +748,12 @@ impl PagedKVCache {
}
}
unsafe fn tensor_from_owned_buf(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor {
unsafe fn tensor_from_owned_buf(
buf: GpuBuffer,
shape: &[usize],
dtype: DType,
device: u32,
) -> Tensor {
use smallvec::SmallVec;
use xserv_tensor::shape::contiguous_strides;
use xserv_tensor::storage::Storage;

View File

@@ -1,7 +1,7 @@
use std::collections::HashMap;
use half::bf16;
use std::collections::HashMap;
use xserv_kernels::*;
use xserv_tensor::{DType, Device, Tensor};
use xserv_tensor::{Device, Tensor};
use crate::config::ModelConfig;
use crate::gpt2::KVCache;
@@ -13,7 +13,7 @@ pub struct Qwen3 {
embed_tokens: Tensor,
layers: Vec<Qwen3Block>,
norm: Tensor,
lm_head_t: Tensor, // precomputed transpose
lm_head_t: Tensor, // precomputed transpose
rope_cache: RopeCache,
// Tensor parallelism. `tp` is None (or world==1) for single-GPU; otherwise
// this rank holds 1/world of the heads and AllReduces after o_proj/down_proj.
@@ -28,22 +28,29 @@ pub struct Qwen3 {
}
struct Qwen3Block {
input_norm: Tensor, // [hidden]
input_norm: Tensor, // [hidden]
qkv_proj_wt: Tensor, // FUSED: [hidden, (H+2*KV)*D] — Q|K|V columns
q_dim: usize, // num_heads * head_dim (Q slice boundary)
kv_dim: usize, // num_kv_heads * head_dim (K/V slice size)
o_proj_wt: Tensor, // TRANSPOSED: [num_heads*head_dim, hidden]
q_norm: Tensor, // [head_dim]
k_norm: Tensor, // [head_dim]
post_norm: Tensor, // [hidden]
gate_up_proj_wt: Tensor, // FUSED: [hidden, 2*intermediate]
down_proj_wt: Tensor, // TRANSPOSED: [intermediate, hidden]
o_proj_wt: Tensor, // TRANSPOSED: [num_heads*head_dim, hidden]
q_norm: Tensor, // [head_dim]
k_norm: Tensor, // [head_dim]
post_norm: Tensor, // [hidden]
gate_up_proj_wt: Tensor, // FUSED: [hidden, 2*intermediate]
down_proj_wt: Tensor, // TRANSPOSED: [intermediate, hidden]
}
impl Qwen3Block {
fn q_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, 0, self.q_dim) }
fn k_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, self.q_dim, self.kv_dim) }
fn v_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, self.q_dim + self.kv_dim, self.kv_dim) }
fn q_proj_wt(&self) -> Tensor {
self.qkv_proj_wt.narrow(1, 0, self.q_dim)
}
fn k_proj_wt(&self) -> Tensor {
self.qkv_proj_wt.narrow(1, self.q_dim, self.kv_dim)
}
fn v_proj_wt(&self) -> Tensor {
self.qkv_proj_wt
.narrow(1, self.q_dim + self.kv_dim, self.kv_dim)
}
fn gate_proj_wt(&self) -> Tensor {
let half = self.gate_up_proj_wt.shape()[1] / 2;
self.gate_up_proj_wt.narrow(1, 0, half)
@@ -80,18 +87,31 @@ impl Qwen3 {
crate::init_kernels();
let dev = Device::Cuda(device);
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
w.remove(name)
.unwrap_or_else(|| panic!("missing weight: {name}"))
};
// Replicated weight: upload whole to this rank's device.
let repl = |t: Tensor| -> Tensor { t.to_device(dev) };
// column-parallel: keep this rank's rows of [out, in], upload, transpose → [in, out/world].
let col = |t: Tensor| -> Tensor { shard_rows(&t, rank, world).to_device(dev).transpose(0, 1).contiguous() };
let col = |t: Tensor| -> Tensor {
shard_rows(&t, rank, world)
.to_device(dev)
.transpose(0, 1)
.contiguous()
};
// row-parallel: keep this rank's cols of [out, in], upload, transpose → [in/world, out].
let row = |t: Tensor| -> Tensor { shard_cols(&t, rank, world).to_device(dev).transpose(0, 1).contiguous() };
let row = |t: Tensor| -> Tensor {
shard_cols(&t, rank, world)
.to_device(dev)
.transpose(0, 1)
.contiguous()
};
let embed_tokens = repl(take(&mut w, "model.embed_tokens.weight"));
let norm = repl(take(&mut w, "model.norm.weight"));
let lm_head_t = repl(take(&mut w, "lm_head.weight")).transpose(0, 1).contiguous();
let lm_head_t = repl(take(&mut w, "lm_head.weight"))
.transpose(0, 1)
.contiguous();
let rope_cache = RopeCache::new(
config.max_seq_len(),
@@ -102,7 +122,10 @@ impl Qwen3 {
let num_layers = config.num_layers();
let mut layers = Vec::with_capacity(num_layers);
if rank == 0 {
eprintln!("Loading+sharding weights for {} layers (world={world})...", num_layers);
eprintln!(
"Loading+sharding weights for {} layers (world={world})...",
num_layers
);
}
for i in 0..num_layers {
let p = format!("model.layers.{i}");
@@ -126,7 +149,10 @@ impl Qwen3 {
o_proj_wt: row(take(&mut w, &format!("{p}.self_attn.o_proj.weight"))),
q_norm: repl(take(&mut w, &format!("{p}.self_attn.q_norm.weight"))),
k_norm: repl(take(&mut w, &format!("{p}.self_attn.k_norm.weight"))),
post_norm: repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight"))),
post_norm: repl(take(
&mut w,
&format!("{p}.post_attention_layernorm.weight"),
)),
gate_up_proj_wt,
down_proj_wt: row(take(&mut w, &format!("{p}.mlp.down_proj.weight"))),
});
@@ -165,7 +191,10 @@ impl Qwen3 {
let dev = Device::Cuda(device);
assert!(num_stages >= 1);
let num_layers = config.num_layers();
assert!(num_layers % num_stages == 0, "num_layers {num_layers} not divisible by pp {num_stages}");
assert!(
num_layers % num_stages == 0,
"num_layers {num_layers} not divisible by pp {num_stages}"
);
let per_stage = num_layers / num_stages;
let lo = stage * per_stage;
let hi = lo + per_stage;
@@ -173,16 +202,29 @@ impl Qwen3 {
let is_last_stage = stage == num_stages - 1;
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
w.remove(name)
.unwrap_or_else(|| panic!("missing weight: {name}"))
};
let repl = |t: Tensor| -> Tensor { t.to_device(dev) };
// Pre-transpose like the TP path's `col`/`row` do for world==1 (no shard).
let wt = |t: Tensor| -> Tensor { t.to_device(dev).transpose(0, 1).contiguous() };
let placeholder = || Tensor::from_slice(&[bf16::ZERO], &[1, 1]).to_device(dev);
let embed_tokens = if is_first_stage { repl(take(&mut w, "model.embed_tokens.weight")) } else { placeholder() };
let norm = if is_last_stage { repl(take(&mut w, "model.norm.weight")) } else { placeholder() };
let lm_head_t = if is_last_stage { wt(take(&mut w, "lm_head.weight")) } else { placeholder() };
let embed_tokens = if is_first_stage {
repl(take(&mut w, "model.embed_tokens.weight"))
} else {
placeholder()
};
let norm = if is_last_stage {
repl(take(&mut w, "model.norm.weight"))
} else {
placeholder()
};
let lm_head_t = if is_last_stage {
wt(take(&mut w, "lm_head.weight"))
} else {
placeholder()
};
let rope_cache = RopeCache::new(
config.max_seq_len(),
@@ -217,7 +259,10 @@ impl Qwen3 {
o_proj_wt: wt(take(&mut w, &format!("{p}.self_attn.o_proj.weight"))),
q_norm: repl(take(&mut w, &format!("{p}.self_attn.q_norm.weight"))),
k_norm: repl(take(&mut w, &format!("{p}.self_attn.k_norm.weight"))),
post_norm: repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight"))),
post_norm: repl(take(
&mut w,
&format!("{p}.post_attention_layernorm.weight"),
)),
gate_up_proj_wt,
down_proj_wt: wt(take(&mut w, &format!("{p}.mlp.down_proj.weight"))),
});
@@ -252,8 +297,12 @@ impl Qwen3 {
matmul_2d(&x, &self.lm_head_t)
}
pub fn pp_is_first(&self) -> bool { self.is_first_stage }
pub fn pp_is_last(&self) -> bool { self.is_last_stage }
pub fn pp_is_first(&self) -> bool {
self.is_first_stage
}
pub fn pp_is_last(&self) -> bool {
self.is_last_stage
}
/// PP prefill over THIS stage's layers. `x` is `[S, hidden]` (stage 0: from
/// `embed`; otherwise received from the previous stage). Writes K/V for this
@@ -276,7 +325,9 @@ impl Qwen3 {
paged_cache.ensure_capacity(slot, pos_offset + new_tokens);
paged_cache.advance_seq_len(slot, new_tokens);
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
.map(|p| p as u32)
.collect();
for (layer_idx, layer) in self.layers.iter().enumerate() {
let residual = x.clone();
@@ -285,7 +336,9 @@ impl Qwen3 {
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
let v = qkv
.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
.contiguous();
let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
@@ -305,10 +358,12 @@ impl Qwen3 {
let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
let attn_out = flash_attention(&q, &k_full, &v_full, true);
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_merged =
xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let (normed, x_new) =
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
@@ -356,7 +411,9 @@ impl Qwen3 {
let qkv_all = matmul_2d(&normed, &layer.qkv_proj_wt);
let q_all = qkv_all.narrow(1, 0, layer.q_dim).contiguous();
let k_all = qkv_all.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
let v_all = qkv_all.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
let v_all = qkv_all
.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
.contiguous();
let mut q_rows: Vec<Tensor> = Vec::with_capacity(batch);
for b in 0..batch {
@@ -394,14 +451,23 @@ impl Qwen3 {
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
let attn_out = xserv_kernels::paged_decode_attention(
&q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr,
batch, num_heads, num_kv_heads, head_dim, max_blocks,
&q_4d,
k_pool_ptr,
v_pool_ptr,
bt_ptr,
cl_ptr,
batch,
num_heads,
num_kv_heads,
head_dim,
max_blocks,
);
let attn_merged = attn_out.reshape(&[batch, num_heads * head_dim]);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let (normed, x_new) =
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
@@ -441,7 +507,9 @@ impl Qwen3 {
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
let mut x = embedding(&self.embed_tokens, token_ids);
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
.map(|p| p as u32)
.collect();
for (layer_idx, layer) in self.layers.iter().enumerate() {
let residual = x.clone();
@@ -450,7 +518,9 @@ impl Qwen3 {
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
let v = qkv
.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
.contiguous();
let q = reshape_heads(&q, new_tokens, num_heads, head_dim);
let k = reshape_heads(&k, new_tokens, num_kv_heads, head_dim);
@@ -531,7 +601,9 @@ impl Qwen3 {
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
let q_all = qkv.narrow(1, 0, layer.q_dim).contiguous();
let k_all = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
let v_all = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
let v_all = qkv
.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
.contiguous();
// Per-sequence: reshape, qk-norm, RoPE, KV cache, attention, merge
let mut attn_outputs: Vec<Tensor> = Vec::with_capacity(batch);
@@ -583,7 +655,8 @@ impl Qwen3 {
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
// Fused add + rmsnorm
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let (normed, x_new) =
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
@@ -662,13 +735,15 @@ impl Qwen3 {
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt); // [B, (H+2*KV)*D]
let q_dim = num_heads * head_dim;
let kv_dim = num_kv_heads * head_dim;
let q_all = qkv.narrow(1, 0, q_dim); // [B, H*D] (view)
let k_all = qkv.narrow(1, q_dim, kv_dim); // [B, KV*D] (view)
let q_all = qkv.narrow(1, 0, q_dim); // [B, H*D] (view)
let k_all = qkv.narrow(1, q_dim, kv_dim); // [B, KV*D] (view)
let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim);
// Per-head RMSNorm on contiguous copies (narrow views are strided).
let q_flat = q_all.contiguous().reshape(&[batch * num_heads, head_dim]);
let k_flat = k_all.contiguous().reshape(&[batch * num_kv_heads, head_dim]);
let k_flat = k_all
.contiguous()
.reshape(&[batch * num_kv_heads, head_dim]);
let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps);
let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps);
@@ -688,8 +763,16 @@ impl Qwen3 {
let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
let attn_out = xserv_kernels::paged_decode_attention(
&q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr,
batch, num_heads, num_kv_heads, head_dim, max_blocks,
&q_4d,
k_pool_ptr,
v_pool_ptr,
bt_ptr,
cl_ptr,
batch,
num_heads,
num_kv_heads,
head_dim,
max_blocks,
);
// attn_out shape [B, H, 1, D] is contiguous-equivalent to [B, H*D].
@@ -697,7 +780,8 @@ impl Qwen3 {
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
self.all_reduce(&attn_proj); // TP: sum partial attention outputs
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let (normed, x_new) =
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
// Fused gate+up projection: one GEMV instead of two.
@@ -743,7 +827,9 @@ impl Qwen3 {
paged_cache.advance_seq_len(slot, new_tokens);
let mut x = embedding(&self.embed_tokens, token_ids);
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
.map(|p| p as u32)
.collect();
for (layer_idx, layer) in self.layers.iter().enumerate() {
let residual = x.clone();
@@ -752,7 +838,9 @@ impl Qwen3 {
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
let v = qkv
.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
.contiguous();
let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
@@ -773,11 +861,13 @@ impl Qwen3 {
let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
let attn_out = flash_attention(&q, &k_full, &v_full, true);
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_merged =
xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
self.all_reduce(&attn_proj);
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let (normed, x_new) =
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
@@ -798,14 +888,16 @@ impl Qwen3 {
pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor {
let new_tokens = token_ids.len();
let pos_offset = cache.seq_len();
let hidden = self.config.hidden();
let num_heads = self.config.num_heads();
let num_kv_heads = self.config.num_kv_heads();
let head_dim = self.config.head_dim();
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
let mut x = embedding(&self.embed_tokens, token_ids);
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
.map(|p| p as u32)
.collect();
for (layer_idx, layer) in self.layers.iter().enumerate() {
let residual = x.clone();
@@ -814,7 +906,9 @@ impl Qwen3 {
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
let v = qkv
.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
.contiguous();
let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
@@ -834,10 +928,12 @@ impl Qwen3 {
let (k_full, v_full) = cache.get_kv_len(layer_idx, pos_offset + new_tokens);
let attn_out = flash_attention(&q, &k_full, &v_full, true);
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_merged =
xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let (normed, x_new) =
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
@@ -856,28 +952,33 @@ impl Qwen3 {
/// Extract weight pointers for CUDA Graph capture.
pub fn layer_weight_ptrs(&self) -> Vec<crate::decode_graph::LayerWeightPtrs> {
self.layers.iter().map(|l| crate::decode_graph::LayerWeightPtrs {
input_norm: l.input_norm.data_ptr() as *const std::ffi::c_void,
q_proj_wt: l.q_proj_wt().data_ptr() as *const std::ffi::c_void,
k_proj_wt: l.k_proj_wt().data_ptr() as *const std::ffi::c_void,
v_proj_wt: l.v_proj_wt().data_ptr() as *const std::ffi::c_void,
o_proj_wt: l.o_proj_wt.data_ptr() as *const std::ffi::c_void,
q_norm: l.q_norm.data_ptr() as *const std::ffi::c_void,
k_norm: l.k_norm.data_ptr() as *const std::ffi::c_void,
post_norm: l.post_norm.data_ptr() as *const std::ffi::c_void,
gate_proj_wt: l.gate_proj_wt().data_ptr() as *const std::ffi::c_void,
up_proj_wt: l.up_proj_wt().data_ptr() as *const std::ffi::c_void,
down_proj_wt: l.down_proj_wt.data_ptr() as *const std::ffi::c_void,
}).collect()
self.layers
.iter()
.map(|l| crate::decode_graph::LayerWeightPtrs {
input_norm: l.input_norm.data_ptr() as *const std::ffi::c_void,
q_proj_wt: l.q_proj_wt().data_ptr() as *const std::ffi::c_void,
k_proj_wt: l.k_proj_wt().data_ptr() as *const std::ffi::c_void,
v_proj_wt: l.v_proj_wt().data_ptr() as *const std::ffi::c_void,
o_proj_wt: l.o_proj_wt.data_ptr() as *const std::ffi::c_void,
q_norm: l.q_norm.data_ptr() as *const std::ffi::c_void,
k_norm: l.k_norm.data_ptr() as *const std::ffi::c_void,
post_norm: l.post_norm.data_ptr() as *const std::ffi::c_void,
gate_proj_wt: l.gate_proj_wt().data_ptr() as *const std::ffi::c_void,
up_proj_wt: l.up_proj_wt().data_ptr() as *const std::ffi::c_void,
down_proj_wt: l.down_proj_wt.data_ptr() as *const std::ffi::c_void,
})
.collect()
}
/// Get pointers needed for CUDA Graph capture.
pub fn graph_capture_ptrs(&self) -> (
*const std::ffi::c_void, // norm weight
*const std::ffi::c_void, // lm_head_t
*const std::ffi::c_void, // embed_tokens
*const std::ffi::c_void, // rope cos
*const std::ffi::c_void, // rope sin
pub fn graph_capture_ptrs(
&self,
) -> (
*const std::ffi::c_void, // norm weight
*const std::ffi::c_void, // lm_head_t
*const std::ffi::c_void, // embed_tokens
*const std::ffi::c_void, // rope cos
*const std::ffi::c_void, // rope sin
) {
(
self.norm.data_ptr() as *const std::ffi::c_void,
@@ -895,11 +996,16 @@ impl Qwen3 {
/// (column-parallel split: split the OUTPUT dim). `world==1` returns the whole.
/// Input must be a contiguous CPU (or device) BF16 tensor.
fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor {
if world == 1 { return t.clone(); }
if world == 1 {
return t.clone();
}
let shape = t.shape();
assert_eq!(shape.len(), 2, "shard_rows expects 2D weight");
let (rows, cols) = (shape[0], shape[1]);
assert!(rows % world == 0, "rows {rows} not divisible by world {world}");
assert!(
rows % world == 0,
"rows {rows} not divisible by world {world}"
);
let local = rows / world;
let host = t.to_device(Device::Cpu);
let data = host.as_slice::<bf16>();
@@ -911,11 +1017,16 @@ fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor {
/// Keep this rank's column-block of a 2D `[rows, cols]` BF16 tensor (row-parallel
/// split: split the INPUT dim). Strided copy. `world==1` returns the whole.
fn shard_cols(t: &Tensor, rank: usize, world: usize) -> Tensor {
if world == 1 { return t.clone(); }
if world == 1 {
return t.clone();
}
let shape = t.shape();
assert_eq!(shape.len(), 2, "shard_cols expects 2D weight");
let (rows, cols) = (shape[0], shape[1]);
assert!(cols % world == 0, "cols {cols} not divisible by world {world}");
assert!(
cols % world == 0,
"cols {cols} not divisible by world {world}"
);
let local = cols / world;
let c0 = rank * local;
let host = t.to_device(Device::Cpu);
@@ -1009,7 +1120,9 @@ fn transpose_from_rope(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: u
}
fn repeat_kv(x: &Tensor, n_rep: usize) -> Tensor {
if n_rep == 1 { return x.clone(); }
if n_rep == 1 {
return x.clone();
}
let kv_heads = x.shape()[1];
let seq_len = x.shape()[2];
let head_dim = x.shape()[3];
@@ -1065,11 +1178,16 @@ fn concat_rows(rows: &[Tensor]) -> Tensor {
let src_buf = row.storage().gpu_buffer();
let src_offset = row.offset() * elem_size;
let dst_offset = b * row_bytes;
out_buf.copy_from_device_at(src_buf, src_offset, dst_offset, row_bytes).unwrap();
out_buf
.copy_from_device_at(src_buf, src_offset, dst_offset, row_bytes)
.unwrap();
}
// Wrap in a Tensor
let device_id = match device { Device::Cuda(id) => id, _ => panic!("expected CUDA device") };
let device_id = match device {
Device::Cuda(id) => id,
_ => panic!("expected CUDA device"),
};
unsafe {
crate::kv_cache::tensor_from_gpu_buffer_pub(out_buf, &[batch, cols], dtype, device_id)
}
@@ -1082,12 +1200,15 @@ fn cat_cols(tensors: &[&Tensor]) -> Tensor {
let dtype = tensors[0].dtype();
let device = tensors[0].device();
let elem = dtype.size_bytes();
let total_cols: usize = tensors.iter().map(|t| {
assert_eq!(t.ndim(), 2);
assert_eq!(t.shape()[0], rows);
assert!(t.is_contiguous());
t.shape()[1]
}).sum();
let total_cols: usize = tensors
.iter()
.map(|t| {
assert_eq!(t.ndim(), 2);
assert_eq!(t.shape()[0], rows);
assert!(t.is_contiguous());
t.shape()[1]
})
.sum();
let out = Tensor::empty(&[rows, total_cols], dtype, device);
let dst_base = out.data_ptr() as *mut u8;
for r in 0..rows {
@@ -1126,7 +1247,9 @@ pub fn sample_greedy(logits: &Tensor) -> u32 {
let seq_len = logits.shape()[0];
let data = logits_cpu.as_slice::<bf16>();
let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
last.iter().enumerate()
last.iter()
.enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(i, _)| i as u32).unwrap()
.map(|(i, _)| i as u32)
.unwrap()
}

View File

@@ -11,7 +11,11 @@ pub struct SamplingParams {
impl Default for SamplingParams {
fn default() -> Self {
Self { temperature: 0.0, top_k: 0, top_p: 1.0 }
Self {
temperature: 0.0,
top_k: 0,
top_p: 1.0,
}
}
}
@@ -19,6 +23,18 @@ impl Default for SamplingParams {
/// Uses the last position's logits. Handles both F32 and BF16 dtypes.
pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
assert_eq!(logits.ndim(), 2);
// Greedy fast path: GPU argmax + 4-byte D2H instead of copying the whole
// [seq, vocab] logits to the host and scanning it (~201k bf16/token).
// NaN logits lose every `>` comparison in the kernel, matching the
// NaN-safe host argmax below.
if params.temperature == 0.0
&& logits.dtype() == DType::BF16
&& matches!(logits.device(), Device::Cuda(_))
&& logits.is_contiguous()
{
let ids = xserv_kernels::argmax_bf16_to_host(logits);
return *ids.last().unwrap();
}
let vocab_size = logits.shape()[1];
let seq_len = logits.shape()[0];
let logits_cpu = logits.to_device(Device::Cpu);
@@ -122,9 +138,14 @@ pub fn sample_greedy_penalized(logits: &Tensor, recent: &[u32], penalty: f32) ->
let seq_len = logits.shape()[0];
let logits_cpu = logits.to_device(Device::Cpu);
let mut last_row: Vec<f32> = match logits.dtype() {
DType::F32 => logits_cpu.as_slice::<f32>()[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec(),
DType::BF16 => logits_cpu.as_slice::<bf16>()[(seq_len - 1) * vocab_size..seq_len * vocab_size]
.iter().map(|v| v.to_f32()).collect(),
DType::F32 => {
logits_cpu.as_slice::<f32>()[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec()
}
DType::BF16 => logits_cpu.as_slice::<bf16>()
[(seq_len - 1) * vocab_size..seq_len * vocab_size]
.iter()
.map(|v| v.to_f32())
.collect(),
_ => panic!("unsupported dtype for sampling: {:?}", logits.dtype()),
};
if penalty > 1.0 {

View File

@@ -72,7 +72,10 @@ impl ChatTemplate {
let source = std::fs::read_to_string(&jinja_path)
.unwrap_or_else(|e| panic!("failed to read {}: {e}", jinja_path.display()));
eprintln!("[chat-template] loaded from {}", jinja_path.display());
return Self { source, model_type: model_type.to_string() };
return Self {
source,
model_type: model_type.to_string(),
};
}
// 2. Try tokenizer_config.json → chat_template field
@@ -82,7 +85,10 @@ impl ChatTemplate {
if let Ok(v) = serde_json::from_str::<serde_json::Value>(&data) {
if let Some(ct) = v.get("chat_template").and_then(|v| v.as_str()) {
eprintln!("[chat-template] loaded from tokenizer_config.json");
return Self { source: ct.to_string(), model_type: model_type.to_string() };
return Self {
source: ct.to_string(),
model_type: model_type.to_string(),
};
}
}
}
@@ -90,7 +96,10 @@ impl ChatTemplate {
// 3. No template found — use empty source, will fall back to hardcoded
eprintln!("[chat-template] no Jinja template found, using hardcoded fallback");
Self { source: String::new(), model_type: model_type.to_string() }
Self {
source: String::new(),
model_type: model_type.to_string(),
}
}
pub fn render(&self, messages: &[Message]) -> String {
@@ -206,7 +215,10 @@ fn build_prompt_gpt_oss(messages: &[Message]) -> String {
prompt.push_str("<|start|>system<|message|>");
prompt.push_str("You are ChatGPT, a large language model trained by OpenAI.\n");
prompt.push_str("Knowledge cutoff: 2024-06\n");
prompt.push_str(&format!("Current date: {}\n\n", strftime_now("%Y-%m-%d".to_string())));
prompt.push_str(&format!(
"Current date: {}\n\n",
strftime_now("%Y-%m-%d".to_string())
));
prompt.push_str("Reasoning: low\n\n");
prompt.push_str("# Valid channels: analysis, commentary, final. Channel must be included for every message.");
prompt.push_str("<|end|>");
@@ -334,13 +346,11 @@ async fn chat_non_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
"completion_tokens": completion_token_count,
"total_tokens": prompt_token_count + completion_token_count
}
})).into_response()
}))
.into_response()
}
fn chat_stream(
state: Arc<AppState>,
req: ChatRequest,
) -> Response {
fn chat_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
let id = format!("chatcmpl-{}", Uuid::new_v4());
let model_name = state.model_name.clone();
let created = unix_timestamp();
@@ -356,7 +366,8 @@ fn chat_stream(
if prompt_tokens.len() >= max_seq_len {
return bad_request(format!(
"prompt is {} tokens, exceeds max_seq_len {}",
prompt_tokens.len(), max_seq_len
prompt_tokens.len(),
max_seq_len
));
}
let max_tokens = req.max_tokens.min(max_seq_len - prompt_tokens.len());
@@ -413,7 +424,9 @@ fn chat_stream(
}
});
Sse::new(ReceiverStream::new(sse_rx)).keep_alive(KeepAlive::default()).into_response()
Sse::new(ReceiverStream::new(sse_rx))
.keep_alive(KeepAlive::default())
.into_response()
}
fn validate_request(req: &ChatRequest, model_name: &str) -> Option<Response> {
@@ -436,8 +449,13 @@ fn validate_request(req: &ChatRequest, model_name: &str) -> Option<Response> {
/// prior handler panicked) and returns a clean 503 instead of panicking when the
/// engine thread is gone, so one dead engine doesn't cascade into every request.
fn submit_to_engine(state: &AppState, req: GenerateRequest) -> Result<(), Response> {
let sender = state.engine_sender.lock().unwrap_or_else(|e| e.into_inner());
sender.send(req).map_err(|_| service_unavailable("inference engine is not available"))
let sender = state
.engine_sender
.lock()
.unwrap_or_else(|e| e.into_inner());
sender
.send(req)
.map_err(|_| service_unavailable("inference engine is not available"))
}
fn service_unavailable(message: impl Into<String>) -> Response {

View File

@@ -1,10 +1,10 @@
use std::collections::VecDeque;
use std::path::Path;
use std::sync::mpsc;
use std::sync::Once;
use std::sync::mpsc;
use std::time::Instant;
use xserv_model::{ModelConfig, PagedKVCache, Qwen3, SamplingParams, sample, BLOCK_SIZE};
use xserv_model::loader;
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, SamplingParams, sample};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
@@ -109,12 +109,23 @@ impl Engine {
(total_blocks * bytes_per_block) as f64 / 1e9,
info.free_memory as f64 / 1e9,
);
Self { model, config, tokenizer, max_batch_size, max_seq_len, paged_cache }
Self {
model,
config,
tokenizer,
max_batch_size,
max_seq_len,
paged_cache,
}
}
pub fn tokenizer(&self) -> &Tokenizer { &self.tokenizer }
pub fn tokenizer(&self) -> &Tokenizer {
&self.tokenizer
}
pub fn max_seq_len(&self) -> usize { self.max_seq_len }
pub fn max_seq_len(&self) -> usize {
self.max_seq_len
}
/// Main scheduler loop. Receives requests from channel, manages concurrent sequences.
///
@@ -134,7 +145,8 @@ impl Engine {
loop {
// Step 1: Remove finished sequences and return their slots.
let finished_slots: Vec<usize> = running.iter()
let finished_slots: Vec<usize> = running
.iter()
.filter(|s| is_finished(s))
.filter_map(|s| s.seq_slot)
.collect();
@@ -147,10 +159,16 @@ impl Engine {
// room (oldest first). They resume decoding from where they paused.
while running.len() < self.max_batch_size && !swapped.is_empty() {
let slot = swapped[0].seq_slot.expect("swapped slot");
if !self.paged_cache.can_swap_in(slot) { break; }
if !self.paged_cache.can_swap_in(slot) {
break;
}
self.paged_cache.swap_in(slot).expect("swap_in");
let seq = swapped.remove(0);
eprintln!("[scheduler] swapped in seq {} ({} blocks)", seq.id, self.paged_cache.block_count(slot));
eprintln!(
"[scheduler] swapped in seq {} ({} blocks)",
seq.id,
self.paged_cache.block_count(slot)
);
running.push(seq);
}
@@ -161,14 +179,22 @@ impl Engine {
let mut avail = self.paged_cache.free_blocks();
let decode_reserve = running.len();
while running.len() < self.max_batch_size {
let Some(front) = waiting.front() else { break; };
let Some(front) = waiting.front() else {
break;
};
let prompt_blocks = front.prompt_tokens.len().div_ceil(BLOCK_SIZE).max(1);
if avail < prompt_blocks + decode_reserve { break; }
if avail < prompt_blocks + decode_reserve {
break;
}
let free_slot = (0..self.paged_cache.max_seqs())
.find(|&s| self.paged_cache.is_slot_free(s));
let Some(slot) = free_slot else { break; };
let Some(slot) = free_slot else {
break;
};
let mut seq = waiting.pop_front().unwrap();
self.paged_cache.register_sequence(slot).expect("register paged slot");
self.paged_cache
.register_sequence(slot)
.expect("register paged slot");
seq.seq_slot = Some(slot);
running.push(seq);
avail -= prompt_blocks; // projected free after this seq prefills
@@ -199,7 +225,9 @@ impl Engine {
if !seq.prefilled {
let slot = seq.seq_slot.expect("slot");
let logits = self.model.forward_prefill_paged(
&seq.prompt_tokens, slot, &mut self.paged_cache,
&seq.prompt_tokens,
slot,
&mut self.paged_cache,
);
let next = sample(&logits, &seq.sampling);
seq.generated_tokens.push(next);
@@ -219,13 +247,18 @@ impl Engine {
&& !newly_prefilled.contains(&running[p].id)
&& running[p].seq_slot.is_some()
});
let Some(pos) = victim else { break; };
let Some(pos) = victim else {
break;
};
let seq = running.remove(pos);
let slot = seq.seq_slot.unwrap();
if self.paged_cache.can_swap_out(slot) {
let nblocks = self.paged_cache.block_count(slot);
self.paged_cache.swap_out(slot).expect("swap_out");
eprintln!("[scheduler] preempt: swapped out seq {} ({nblocks} blocks) to host", seq.id);
eprintln!(
"[scheduler] preempt: swapped out seq {} ({nblocks} blocks) to host",
seq.id
);
swapped.push(seq);
needed = decode_block_need(&self.paged_cache, &running, &newly_prefilled);
} else {
@@ -235,7 +268,9 @@ impl Engine {
}
// Step 5c: Batched paged decode for the surviving prefilled sequences.
let decode_indices: Vec<usize> = running.iter().enumerate()
let decode_indices: Vec<usize> = running
.iter()
.enumerate()
.filter(|(_, s)| s.prefilled && !newly_prefilled.contains(&s.id))
.map(|(i, _)| i)
.collect();
@@ -246,25 +281,32 @@ impl Engine {
eprintln!("[scheduler] paged decode active");
});
let tokens: Vec<u32> = decode_indices.iter()
let tokens: Vec<u32> = decode_indices
.iter()
.map(|&i| *running[i].generated_tokens.last().unwrap())
.collect();
let positions: Vec<usize> = decode_indices.iter()
let positions: Vec<usize> = decode_indices
.iter()
.map(|&i| self.paged_cache.seq_len(running[i].seq_slot.unwrap()))
.collect();
let slots: Vec<usize> = decode_indices.iter()
let slots: Vec<usize> = decode_indices
.iter()
.map(|&i| running[i].seq_slot.unwrap())
.collect();
let logits = self.model.forward_decode_paged(
&tokens, &positions, &slots, &mut self.paged_cache,
&tokens,
&positions,
&slots,
&mut self.paged_cache,
);
// Fast path: every active sequence is greedy → run argmax on
// the GPU and only D2H the chosen token ids (a few bytes per
// sequence) instead of the full [B, vocab_size] BF16 logits
// (~1.2 MB for B=4, Qwen3 vocab=152K).
let all_greedy = decode_indices.iter()
let all_greedy = decode_indices
.iter()
.all(|&i| running[i].sampling.temperature == 0.0);
if all_greedy {
let next_ids = xserv_kernels::argmax_bf16_to_host(&logits);
@@ -285,11 +327,15 @@ impl Engine {
let row_start = j * vocab_size;
let row_logits = &data[row_start..row_start + vocab_size];
let next = if running[i].sampling.temperature == 0.0 {
row_logits.iter().enumerate()
row_logits
.iter()
.enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(idx, _)| idx as u32).unwrap()
.map(|(idx, _)| idx as u32)
.unwrap()
} else {
let row_tensor = xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]);
let row_tensor =
xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]);
sample(&row_tensor, &running[i].sampling)
};
running[i].generated_tokens.push(next);
@@ -334,7 +380,8 @@ impl Engine {
/// Total additional GPU blocks the next decode step needs across all
/// currently-decoding (prefilled, not just-prefilled) sequences.
fn decode_block_need(paged: &PagedKVCache, running: &[Sequence], newly_prefilled: &[u64]) -> usize {
running.iter()
running
.iter()
.filter(|s| s.prefilled && !newly_prefilled.contains(&s.id))
.filter_map(|s| s.seq_slot)
.map(|slot| paged.additional_blocks_needed(slot, 1))
@@ -372,8 +419,12 @@ fn send_token_if_nonempty(seq: &Sequence, text: String) {
}
fn is_finished(seq: &Sequence) -> bool {
if seq.generated_tokens.is_empty() { return false; }
if seq.generated_tokens.is_empty() {
return false;
}
let last = *seq.generated_tokens.last().unwrap();
if seq.generated_tokens.len() >= seq.max_tokens { return true; }
if seq.generated_tokens.len() >= seq.max_tokens {
return true;
}
seq.sender.is_closed() || seq.eos_token_id == Some(last)
}

View File

@@ -3,10 +3,13 @@ mod engine;
mod pp_engine;
mod tp_engine;
use axum::{routing::{get, post}, Extension, Router};
use std::path::PathBuf;
use std::sync::{mpsc, Arc, Mutex};
use axum::{
Extension, Router,
routing::{get, post},
};
use engine::GenerateRequest;
use std::path::PathBuf;
use std::sync::{Arc, Mutex, mpsc};
use xserv_model::ModelConfig;
pub struct AppState {
@@ -21,40 +24,48 @@ pub struct AppState {
async fn main() {
let args: Vec<String> = std::env::args().collect();
if args.len() < 2 {
eprintln!("Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N] [--tp N] [--pp N]");
eprintln!(
"Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N] [--tp N] [--pp N]"
);
std::process::exit(1);
}
let model_dir = PathBuf::from(&args[1]);
let port: u16 = args.iter()
let port: u16 = args
.iter()
.position(|a| a == "--port")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(8080);
let max_batch: usize = args.iter()
let max_batch: usize = args
.iter()
.position(|a| a == "--max-batch")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(4)
.max(1);
let requested_max_seq_len: usize = args.iter()
let requested_max_seq_len: usize = args
.iter()
.position(|a| a == "--max-seq-len")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(2048)
.max(1);
let swap_space_gb: usize = args.iter()
let swap_space_gb: usize = args
.iter()
.position(|a| a == "--swap-space-gb")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(8);
let tp: usize = args.iter()
let tp: usize = args
.iter()
.position(|a| a == "--tp")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(1)
.max(1);
let pp: usize = args.iter()
let pp: usize = args
.iter()
.position(|a| a == "--pp")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
@@ -65,6 +76,15 @@ async fn main() {
std::process::exit(1);
}
let model_config = ModelConfig::from_file(&model_dir.join("config.json"));
// gpt-oss is only implemented in the TP engine; route it there even at
// tp=1 (single-rank world) so quantized models can serve on one GPU.
let is_gpt_oss = model_config.model_type.as_deref() == Some("gpt_oss");
if pp > 1 && is_gpt_oss {
eprintln!(
"gpt-oss is not supported by the pipeline-parallel engine (Qwen3 only); use --tp instead"
);
std::process::exit(1);
}
let model_max_seq_len = model_config.max_seq_len();
if model_max_seq_len == 0 {
eprintln!("model config has invalid max_seq_len=0");
@@ -77,7 +97,8 @@ async fn main() {
);
}
let model_name = model_dir.file_name()
let model_name = model_dir
.file_name()
.map(|n| n.to_string_lossy().to_string())
.unwrap_or_else(|| "unknown".to_string());
@@ -91,8 +112,13 @@ async fn main() {
if pp > 1 {
// Pipeline-parallel path: stage-0 coordinator + worker stage threads.
pp_engine::run_pp(&model_dir_clone, pp, max_seq_len, rx);
} else if tp <= 1 {
let mut engine = engine::Engine::load_with_swap(&model_dir_clone, max_batch, max_seq_len, swap_space_gb);
} else if tp <= 1 && !is_gpt_oss {
let mut engine = engine::Engine::load_with_swap(
&model_dir_clone,
max_batch,
max_seq_len,
swap_space_gb,
);
engine.run(rx);
} else {
// Tensor-parallel path: rank-0 coordinator + worker rank threads.

View File

@@ -15,15 +15,15 @@
use std::ffi::c_void;
use std::path::{Path, PathBuf};
use std::sync::mpsc;
use std::sync::Arc;
use std::sync::mpsc;
use std::thread;
use half::bf16;
use xserv_distributed::{PpContext, UniqueId};
use xserv_model::loader;
use xserv_model::sampling::SamplingParams;
use xserv_model::{sample, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, sample};
use xserv_tensor::{DType, Device, Tensor};
use xserv_tokenizer::Tokenizer;
@@ -38,9 +38,16 @@ enum PpCommand {
Free(usize),
/// Receive `[n_tokens, hidden]` from the previous stage, run this stage's
/// layers; if last stage, sample with `sampling` and return the token.
Prefill { n_tokens: usize, slot: usize, sampling: SamplingParams },
Prefill {
n_tokens: usize,
slot: usize,
sampling: SamplingParams,
},
/// Receive `[1, hidden]`, run this stage's layers; last stage samples.
Decode { slot: usize, sampling: SamplingParams },
Decode {
slot: usize,
sampling: SamplingParams,
},
Shutdown,
}
@@ -76,9 +83,21 @@ fn build_stage(
let max_blocks_per_seq = max_seq_len.div_ceil(BLOCK_SIZE);
let total_blocks = max_blocks_per_seq + 8; // v1 serial: one active sequence
let cache = PagedKVCache::new(
&stage_config, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device,
&stage_config,
total_blocks,
0,
4,
max_blocks_per_seq,
DType::BF16,
device,
);
StageCtx { model, cache, pp, hidden: config.hidden(), device }
StageCtx {
model,
cache,
pp,
hidden: config.hidden(),
device,
}
}
/// Allocate a zeroed `[n, hidden]` device tensor and receive into it from `peer`.
@@ -110,7 +129,15 @@ fn worker_loop(
ack_tx: mpsc::Sender<()>,
token_tx: mpsc::Sender<u32>,
) {
let mut sc = build_stage(&model_dir, &config, stage, world, stage as u32, max_seq_len, id);
let mut sc = build_stage(
&model_dir,
&config,
stage,
world,
stage as u32,
max_seq_len,
id,
);
let is_last = stage == world - 1;
let prev = stage - 1;
let next = stage + 1;
@@ -125,7 +152,11 @@ fn worker_loop(
sc.cache.free_sequence(slot);
let _ = ack_tx.send(());
}
PpCommand::Prefill { n_tokens, slot, sampling } => {
PpCommand::Prefill {
n_tokens,
slot,
sampling,
} => {
let x = recv_hidden(&sc, n_tokens, prev);
let x = sc.model.forward_layers_prefill(x, slot, &mut sc.cache);
if is_last {
@@ -155,7 +186,12 @@ fn worker_loop(
/// Run the PP coordinator (stage 0) on the calling thread. Spawns worker stages
/// 1..world and consumes generation requests from `rx`.
pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Receiver<GenerateRequest>) {
pub fn run_pp(
model_dir: &Path,
world: usize,
max_seq_len: usize,
rx: mpsc::Receiver<GenerateRequest>,
) {
assert!(world >= 2, "run_pp requires world >= 2");
let config = ModelConfig::from_file(&model_dir.join("config.json"));
assert!(
@@ -179,7 +215,17 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
let model_dir = model_dir.to_path_buf();
let config = config.clone();
thread::spawn(move || {
worker_loop(stage, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx, token_tx);
worker_loop(
stage,
world,
id,
model_dir,
config,
max_seq_len,
ctx_rx,
ack_tx,
token_tx,
);
});
}
@@ -207,11 +253,14 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
wait_acks(&ack_rx);
// Prefill: embed prompt, run stage-0 layers, push hidden into the pipe.
broadcast(&cmd_txs, PpCommand::Prefill {
n_tokens: req.prompt_tokens.len(),
slot,
sampling: req.sampling.clone(),
});
broadcast(
&cmd_txs,
PpCommand::Prefill {
n_tokens: req.prompt_tokens.len(),
slot,
sampling: req.sampling.clone(),
},
);
let x = sc.model.embed(&req.prompt_tokens);
let x = sc.model.forward_layers_prefill(x, slot, &mut sc.cache);
send_hidden(&sc, &x, next_peer);
@@ -228,7 +277,13 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
if generated >= req.max_tokens {
break "length";
}
broadcast(&cmd_txs, PpCommand::Decode { slot, sampling: req.sampling.clone() });
broadcast(
&cmd_txs,
PpCommand::Decode {
slot,
sampling: req.sampling.clone(),
},
);
let x = sc.model.embed(&[next]);
let x = sc.model.forward_layers_decode(x, &[slot], &mut sc.cache);
send_hidden(&sc, &x, next_peer);
@@ -239,9 +294,14 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
let tail = tokenizer.flush_decode_stream(&mut decode_buf);
if !tail.is_empty() {
let _ = req.sender.blocking_send(GenerateEvent::Token { id: next, text: tail });
let _ = req.sender.blocking_send(GenerateEvent::Token {
id: next,
text: tail,
});
}
let _ = req.sender.blocking_send(GenerateEvent::Done { finish_reason: finish.to_string() });
let _ = req.sender.blocking_send(GenerateEvent::Done {
finish_reason: finish.to_string(),
});
broadcast(&cmd_txs, PpCommand::Free(slot));
sc.cache.free_sequence(slot);
@@ -258,6 +318,8 @@ fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, buf: &
}
let text = tokenizer.decode_token_stream(token_id, buf);
if !text.is_empty() {
let _ = req.sender.blocking_send(GenerateEvent::Token { id: token_id, text });
let _ = req
.sender
.blocking_send(GenerateEvent::Token { id: token_id, text });
}
}

View File

@@ -13,13 +13,16 @@
//! work; the single-GPU `Engine` still handles TP=1.
use std::path::{Path, PathBuf};
use std::sync::mpsc;
use std::sync::Arc;
use std::sync::mpsc;
use std::thread;
use xserv_distributed::{TpContext, UniqueId};
use xserv_model::loader;
use xserv_model::{sample, sample_greedy_penalized, GptOss, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
use xserv_model::{
BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, sample,
sample_greedy_penalized,
};
use xserv_tensor::{DType, Device, Tensor};
use xserv_tokenizer::Tokenizer;
@@ -29,8 +32,15 @@ use crate::engine::{GenerateEvent, GenerateRequest};
enum TpCommand {
Register(usize),
Free(usize),
Prefill { tokens: Vec<u32>, slot: usize },
Decode { tokens: Vec<u32>, positions: Vec<usize>, slots: Vec<usize> },
Prefill {
tokens: Vec<u32>,
slot: usize,
},
Decode {
tokens: Vec<u32>,
positions: Vec<usize>,
slots: Vec<usize>,
},
Shutdown,
}
@@ -40,14 +50,25 @@ enum TpModel {
}
impl TpModel {
fn forward_prefill_paged(&self, tokens: &[u32], slot: usize, cache: &mut PagedKVCache) -> Tensor {
fn forward_prefill_paged(
&self,
tokens: &[u32],
slot: usize,
cache: &mut PagedKVCache,
) -> Tensor {
match self {
TpModel::Qwen3(m) => m.forward_prefill_paged(tokens, slot, cache),
TpModel::GptOss(m) => m.forward_prefill_paged(tokens, slot, cache),
}
}
fn forward_decode_paged(&self, tokens: &[u32], positions: &[usize], slots: &[usize], cache: &mut PagedKVCache) -> Tensor {
fn forward_decode_paged(
&self,
tokens: &[u32],
positions: &[usize],
slots: &[usize],
cache: &mut PagedKVCache,
) -> Tensor {
match self {
TpModel::Qwen3(m) => m.forward_decode_paged(tokens, positions, slots, cache),
TpModel::GptOss(m) => m.forward_decode_paged(tokens, positions, slots, cache),
@@ -58,6 +79,20 @@ impl TpModel {
struct RankCtx {
model: TpModel,
cache: PagedKVCache,
decoder: GraphedGptOssDecoder,
}
/// Decode one step: gpt-oss batch=1 goes through the CUDA-graph decoder
/// (lazy capture, replay thereafter); everything else runs eager.
fn rank_decode(rc: &mut RankCtx, tokens: &[u32], positions: &[usize], slots: &[usize]) -> Tensor {
match &rc.model {
TpModel::GptOss(m) => rc
.decoder
.decode(m, tokens, positions, slots, &mut rc.cache),
TpModel::Qwen3(_) => rc
.model
.forward_decode_paged(tokens, positions, slots, &mut rc.cache),
}
}
fn build_rank(
@@ -71,17 +106,42 @@ fn build_rank(
) -> RankCtx {
let weights = loader::load_model_dir(model_dir, Device::Cpu);
let model = if config.is_moe() {
TpModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, rank, world, device, tp))
TpModel::GptOss(GptOss::from_weights_tp(
config.clone(),
weights,
rank,
world,
device,
tp,
))
} else {
TpModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, rank, world, device, tp))
TpModel::Qwen3(Qwen3::from_weights_tp(
config.clone(),
weights,
rank,
world,
device,
tp,
))
};
let local_kv = config.num_kv_heads() / world;
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
let total_blocks = max_blocks_per_seq + 8;
let cache = PagedKVCache::new_tp(
config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device,
config,
local_kv,
total_blocks,
0,
4,
max_blocks_per_seq,
DType::BF16,
device,
);
RankCtx { model, cache }
RankCtx {
model,
cache,
decoder: GraphedGptOssDecoder::new(),
}
}
fn worker_loop(
@@ -95,7 +155,15 @@ fn worker_loop(
ack_tx: mpsc::Sender<()>,
) {
let tp = Arc::new(TpContext::init(rank, world, id, rank as u32));
let mut rc = build_rank(&model_dir, &config, rank, world, rank as u32, max_seq_len, Some(tp));
let mut rc = build_rank(
&model_dir,
&config,
rank,
world,
rank as u32,
max_seq_len,
Some(tp),
);
while let Ok(cmd) = cmd_rx.recv() {
match cmd {
TpCommand::Register(slot) => {
@@ -105,8 +173,12 @@ fn worker_loop(
TpCommand::Prefill { tokens, slot } => {
let _ = rc.model.forward_prefill_paged(&tokens, slot, &mut rc.cache);
}
TpCommand::Decode { tokens, positions, slots } => {
let _ = rc.model.forward_decode_paged(&tokens, &positions, &slots, &mut rc.cache);
TpCommand::Decode {
tokens,
positions,
slots,
} => {
let _ = rank_decode(&mut rc, &tokens, &positions, &slots);
}
TpCommand::Shutdown => {
let _ = ack_tx.send(());
@@ -119,8 +191,15 @@ fn worker_loop(
/// Run the TP coordinator (rank 0) on the calling thread. Spawns worker ranks
/// internally and consumes generation requests from `rx`.
pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Receiver<GenerateRequest>) {
assert!(world >= 2, "run_tp requires world >= 2");
pub fn run_tp(
model_dir: &Path,
world: usize,
max_seq_len: usize,
rx: mpsc::Receiver<GenerateRequest>,
) {
// world=1 is a valid single-rank configuration (gpt-oss has no
// single-GPU engine path; NCCL init and all_reduce no-op at world=1).
assert!(world >= 1, "run_tp requires world >= 1");
let config = ModelConfig::from_file(&model_dir.join("config.json"));
assert!(
config.num_kv_heads() % world == 0,
@@ -140,7 +219,16 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
let model_dir = model_dir.to_path_buf();
let config = config.clone();
thread::spawn(move || {
worker_loop(rank, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx);
worker_loop(
rank,
world,
id,
model_dir,
config,
max_seq_len,
ctx_rx,
ack_tx,
);
});
}
@@ -153,10 +241,14 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
// models loop under pure greedy when numerics diverge from the reference).
// Off by default; XSERV_REP_PENALTY>1 enables it over the last
// XSERV_REP_WINDOW generated tokens. Applied only on the greedy path.
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY").ok()
.and_then(|s| s.parse().ok()).unwrap_or(1.0);
let rep_window: usize = std::env::var("XSERV_REP_WINDOW").ok()
.and_then(|s| s.parse().ok()).unwrap_or(128);
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(1.0);
let rep_window: usize = std::env::var("XSERV_REP_WINDOW")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(128);
let pick = |logits: &Tensor, sp: &xserv_model::SamplingParams, history: &[u32]| -> u32 {
if rep_penalty > 1.0 && sp.temperature == 0.0 {
let start = history.len().saturating_sub(rep_window);
@@ -185,8 +277,16 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
wait_acks(&ack_rx);
// Prefill.
broadcast(&cmd_txs, TpCommand::Prefill { tokens: req.prompt_tokens.clone(), slot });
let logits = rc.model.forward_prefill_paged(&req.prompt_tokens, slot, &mut rc.cache);
broadcast(
&cmd_txs,
TpCommand::Prefill {
tokens: req.prompt_tokens.clone(),
slot,
},
);
let logits = rc
.model
.forward_prefill_paged(&req.prompt_tokens, slot, &mut rc.cache);
wait_acks(&ack_rx);
let mut gen_ids: Vec<u32> = Vec::new();
let mut next = pick(&logits, &req.sampling, &gen_ids);
@@ -204,8 +304,15 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
break "length";
}
let pos = rc.cache.seq_len(slot);
broadcast(&cmd_txs, TpCommand::Decode { tokens: vec![next], positions: vec![pos], slots: vec![slot] });
let logits = rc.model.forward_decode_paged(&[next], &[pos], &[slot], &mut rc.cache);
broadcast(
&cmd_txs,
TpCommand::Decode {
tokens: vec![next],
positions: vec![pos],
slots: vec![slot],
},
);
let logits = rank_decode(&mut rc, &[next], &[pos], &[slot]);
wait_acks(&ack_rx);
next = pick(&logits, &req.sampling, &gen_ids);
gen_ids.push(next);
@@ -215,9 +322,14 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
let tail = tokenizer.flush_decode_stream(&mut decode_buf);
if !tail.is_empty() {
let _ = req.sender.blocking_send(GenerateEvent::Token { id: next, text: tail });
let _ = req.sender.blocking_send(GenerateEvent::Token {
id: next,
text: tail,
});
}
let _ = req.sender.blocking_send(GenerateEvent::Done { finish_reason: finish.to_string() });
let _ = req.sender.blocking_send(GenerateEvent::Done {
finish_reason: finish.to_string(),
});
broadcast(&cmd_txs, TpCommand::Free(slot));
rc.cache.free_sequence(slot);
@@ -234,6 +346,8 @@ fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, buf: &
}
let text = tokenizer.decode_token_stream(token_id, buf);
if !text.is_empty() {
let _ = req.sender.blocking_send(GenerateEvent::Token { id: token_id, text });
let _ = req
.sender
.blocking_send(GenerateEvent::Token { id: token_id, text });
}
}

View File

@@ -43,18 +43,30 @@ pub trait TensorDType: Copy + Send + Sync + 'static {
impl TensorDType for f32 {
const DTYPE: DType = DType::F32;
fn to_f64(self) -> f64 { self as f64 }
fn from_f64(v: f64) -> Self { v as f32 }
fn to_f64(self) -> f64 {
self as f64
}
fn from_f64(v: f64) -> Self {
v as f32
}
}
impl TensorDType for f16 {
const DTYPE: DType = DType::F16;
fn to_f64(self) -> f64 { self.to_f32() as f64 }
fn from_f64(v: f64) -> Self { f16::from_f32(v as f32) }
fn to_f64(self) -> f64 {
self.to_f32() as f64
}
fn from_f64(v: f64) -> Self {
f16::from_f32(v as f32)
}
}
impl TensorDType for bf16 {
const DTYPE: DType = DType::BF16;
fn to_f64(self) -> f64 { self.to_f32() as f64 }
fn from_f64(v: f64) -> Self { bf16::from_f32(v as f32) }
fn to_f64(self) -> f64 {
self.to_f32() as f64
}
fn from_f64(v: f64) -> Self {
bf16::from_f32(v as f32)
}
}

View File

@@ -6,4 +6,4 @@ pub mod tensor;
pub use dtype::{DType, TensorDType};
pub use shape::Dims;
pub use storage::{Device, Storage};
pub use tensor::{register_gpu_contiguous, Tensor};
pub use tensor::{Tensor, register_gpu_contiguous};

View File

@@ -46,8 +46,16 @@ pub fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Dims> {
let ndim = a.len().max(b.len());
let mut result = SmallVec::with_capacity(ndim);
for i in 0..ndim {
let da = if i < ndim - a.len() { 1 } else { a[i - (ndim - a.len())] };
let db = if i < ndim - b.len() { 1 } else { b[i - (ndim - b.len())] };
let da = if i < ndim - a.len() {
1
} else {
a[i - (ndim - a.len())]
};
let db = if i < ndim - b.len() {
1
} else {
b[i - (ndim - b.len())]
};
if da == db {
result.push(da);
} else if da == 1 {
@@ -100,8 +108,14 @@ mod tests {
#[test]
fn test_broadcast_shape() {
assert_eq!(broadcast_shape(&[3, 1], &[1, 4]).unwrap().as_slice(), &[3, 4]);
assert_eq!(broadcast_shape(&[2, 3, 4], &[4]).unwrap().as_slice(), &[2, 3, 4]);
assert_eq!(
broadcast_shape(&[3, 1], &[1, 4]).unwrap().as_slice(),
&[3, 4]
);
assert_eq!(
broadcast_shape(&[2, 3, 4], &[4]).unwrap().as_slice(),
&[2, 3, 4]
);
assert_eq!(broadcast_shape(&[1], &[5, 3]).unwrap().as_slice(), &[5, 3]);
assert!(broadcast_shape(&[3], &[4]).is_none());
}
@@ -109,6 +123,9 @@ mod tests {
#[test]
fn test_broadcast_strides() {
// [3,1] with strides [1,1] broadcast to [3,4]
assert_eq!(broadcast_strides(&[3, 1], &[1, 1], &[3, 4]).as_slice(), &[1, 0]);
assert_eq!(
broadcast_strides(&[3, 1], &[1, 1], &[3, 4]).as_slice(),
&[1, 0]
);
}
}

View File

@@ -33,8 +33,20 @@ impl Tensor {
// --- Creation ---
/// Create a tensor from raw components (for advanced use like GPU KV cache).
pub fn from_storage(storage: Storage, shape: Dims, strides: Dims, offset: usize, dtype: DType) -> Self {
Self { storage, shape, strides, offset, dtype }
pub fn from_storage(
storage: Storage,
shape: Dims,
strides: Dims,
offset: usize,
dtype: DType,
) -> Self {
Self {
storage,
shape,
strides,
offset,
dtype,
}
}
pub fn from_slice<T: TensorDType>(data: &[T], shape: &[usize]) -> Self {
@@ -60,7 +72,10 @@ impl Tensor {
data.len(),
numel * dtype.size_bytes(),
"raw bytes length {} != expected {} (numel={} * elem_size={})",
data.len(), numel * dtype.size_bytes(), numel, dtype.size_bytes()
data.len(),
numel * dtype.size_bytes(),
numel,
dtype.size_bytes()
);
Self {
storage: Storage::cpu(data.to_vec()),
@@ -112,14 +127,28 @@ impl Tensor {
// --- Properties ---
pub fn shape(&self) -> &[usize] { &self.shape }
pub fn strides(&self) -> &[usize] { &self.strides }
pub fn dtype(&self) -> DType { self.dtype }
pub fn ndim(&self) -> usize { self.shape.len() }
pub fn numel(&self) -> usize { shape::num_elements(&self.shape) }
pub fn offset(&self) -> usize { self.offset }
pub fn shape(&self) -> &[usize] {
&self.shape
}
pub fn strides(&self) -> &[usize] {
&self.strides
}
pub fn dtype(&self) -> DType {
self.dtype
}
pub fn ndim(&self) -> usize {
self.shape.len()
}
pub fn numel(&self) -> usize {
shape::num_elements(&self.shape)
}
pub fn offset(&self) -> usize {
self.offset
}
pub fn device(&self) -> Device { self.storage.device() }
pub fn device(&self) -> Device {
self.storage.device()
}
pub fn is_contiguous(&self) -> bool {
shape::is_contiguous(&self.shape, &self.strides)
@@ -193,7 +222,11 @@ impl Tensor {
shape::contiguous_strides(&new_shape)
} else {
let mut s = self.strides.clone();
let stride_val = if dim < self.strides.len() { self.strides[dim] } else { 1 };
let stride_val = if dim < self.strides.len() {
self.strides[dim]
} else {
1
};
s.insert(dim, stride_val);
s
};
@@ -230,7 +263,12 @@ impl Tensor {
let ndim = self.ndim();
let mut idx = vec![0usize; ndim];
for flat in 0..numel {
let src_offset = self.offset + idx.iter().zip(self.strides.iter()).map(|(i, s)| i * s).sum::<usize>();
let src_offset = self.offset
+ idx
.iter()
.zip(self.strides.iter())
.map(|(i, s)| i * s)
.sum::<usize>();
let src_byte_offset = src_offset * elem_size;
let dst_byte_offset = flat * elem_size;
dst[dst_byte_offset..dst_byte_offset + elem_size]
@@ -261,7 +299,10 @@ impl Tensor {
}
// Transfer the raw storage (preserving strides/offset).
// Non-contiguous layout is preserved — the user can call contiguous() after.
let new_storage = self.storage.to_device(device).expect("device transfer failed");
let new_storage = self
.storage
.to_device(device)
.expect("device transfer failed");
Self {
storage: new_storage,
shape: self.shape.clone(),
@@ -310,14 +351,20 @@ impl Tensor {
}
}
pub fn storage(&self) -> &Storage { &self.storage }
pub fn storage(&self) -> &Storage {
&self.storage
}
}
impl std::fmt::Debug for Tensor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f, "Tensor(shape={:?}, dtype={}, device={}, contiguous={})",
self.shape.as_slice(), self.dtype, self.device(), self.is_contiguous()
f,
"Tensor(shape={:?}, dtype={}, device={}, contiguous={})",
self.shape.as_slice(),
self.dtype,
self.device(),
self.is_contiguous()
)
}
}

View File

@@ -32,7 +32,11 @@ fn test_zeros_and_ones() {
#[test]
fn test_bf16_tensor() {
let data: Vec<bf16> = vec![bf16::from_f32(1.0), bf16::from_f32(2.5), bf16::from_f32(-3.0)];
let data: Vec<bf16> = vec![
bf16::from_f32(1.0),
bf16::from_f32(2.5),
bf16::from_f32(-3.0),
];
let t = Tensor::from_slice(&data, &[3]);
assert_eq!(t.dtype(), DType::BF16);
let out = t.as_slice::<bf16>();

View File

@@ -95,11 +95,15 @@ impl Tokenizer {
let (a_str, b_str) = match entry {
MergeEntry::Str(s) => {
let parts: Vec<&str> = s.splitn(2, ' ').collect();
if parts.len() != 2 { continue; }
if parts.len() != 2 {
continue;
}
(parts[0].to_string(), parts[1].to_string())
}
MergeEntry::Pair(v) => {
if v.len() != 2 { continue; }
if v.len() != 2 {
continue;
}
(v[0].clone(), v[1].clone())
}
};
@@ -174,7 +178,10 @@ impl Tokenizer {
if byte_fallback {
Regex::new(r"[\p{L}\p{N}]+|[^\s\p{L}\p{N}]|\s+").unwrap()
} else {
Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+").unwrap()
Regex::new(
r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+",
)
.unwrap()
}
}
}
@@ -262,7 +269,9 @@ impl Tokenizer {
// BPE merges
loop {
if token_ids.len() < 2 { break; }
if token_ids.len() < 2 {
break;
}
let mut best_rank = usize::MAX;
let mut best_idx = 0;
for i in 0..token_ids.len() - 1 {
@@ -273,12 +282,15 @@ impl Tokenizer {
}
}
}
if best_rank == usize::MAX { break; }
if best_rank == usize::MAX {
break;
}
let merged_bytes = [
self.decoder[token_ids[best_idx] as usize].as_slice(),
self.decoder[token_ids[best_idx + 1] as usize].as_slice(),
].concat();
]
.concat();
let merged_id = *self.encoder.get(&merged_bytes).unwrap_or_else(|| {
panic!("merged token not in vocab");
});
@@ -389,14 +401,13 @@ fn unicode_to_byte(c: char) -> u8 {
m
});
*map.get(&(c as u32)).unwrap_or_else(|| {
panic!("unmapped unicode char U+{:04X} in tokenizer", c as u32)
})
*map.get(&(c as u32))
.unwrap_or_else(|| panic!("unmapped unicode char U+{:04X} in tokenizer", c as u32))
}
#[cfg(test)]
mod tests {
use super::{take_valid_utf8, Tokenizer};
use super::{Tokenizer, take_valid_utf8};
#[test]
fn qwen_added_tokens_are_indivisible_and_im_end_is_eos() {

View File

@@ -87,6 +87,17 @@ __global__ void add_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b,
if (idx < n) out[idx] = __float2bfloat16(__bfloat162float(a[idx]) + __bfloat162float(b[idx]));
}
// Row-broadcast bias add: out[r, c] = x[r, c] + bias[c]
__global__ void bias_add_2d_bf16_kernel(
const __nv_bfloat16* __restrict__ x, const __nv_bfloat16* __restrict__ bias,
__nv_bfloat16* __restrict__ out, int rows, int cols
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= rows * cols) return;
float v = __bfloat162float(x[idx]) + __bfloat162float(bias[idx % cols]);
out[idx] = __float2bfloat16(v);
}
// Element-wise mul: out = a * b
__global__ void mul_f32_kernel(const float* a, const float* b, float* out, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -159,6 +170,14 @@ void launch_add_bf16(const void* a, const void* b, void* out, int n, void* strea
(const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
CUDA_CHECK_LAST_ERROR();
}
void launch_bias_add_2d_bf16(const void* x, const void* bias, void* out, int rows, int cols, void* stream) {
int n = rows * cols;
int block = 256;
int grid = (n + block - 1) / block;
bias_add_2d_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)x, (const __nv_bfloat16*)bias, (__nv_bfloat16*)out, rows, cols);
CUDA_CHECK_LAST_ERROR();
}
void launch_mul_f32(const void* a, const void* b, void* out, int n, void* stream) {
int block = 256;
int grid = (n + block - 1) / block;

View File

@@ -93,11 +93,9 @@ __global__ void moe_replicate_bf16_kernel(
int total = local_experts * num_tokens * hidden;
if (idx >= total) return;
int expert = idx / (num_tokens * hidden);
int remainder = idx % (num_tokens * hidden);
// x_rep[expert, token, dim] = x[token, dim]
x_rep[idx] = x[remainder];
(void)expert; // suppress unused warning
}
// ============================================================

254
csrc/moe/moe_sparse.cu Normal file
View File

@@ -0,0 +1,254 @@
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cstdint>
#include "../common.cuh"
// ============================================================
// Sparse MoE decode GEMVs — compute ONLY the routed experts.
//
// The dense path replicates x across all local experts and runs a
// batched GEMM, reading every expert's weights per token. Decode is
// memory-bound, so reading only the top-k routed experts' weights
// (~2 of 16 local on average at TP=2) is a ~8x byte reduction.
//
// Each block handles one (token, slot) pair's tile of output columns.
// It reads topk_ids[token, slot] from device memory (no host sync),
// and exits early if the expert is not owned by this rank. The early
// return is BLOCK-UNIFORM (every thread sees the same topk_ids value
// and returns before the shared-memory staging + __syncthreads), so
// it is safe — unlike the divergent-return bug fixed in gemv.cu.
//
// Outputs for non-local slots are NEVER written (uninitialized memory,
// possibly NaN bit patterns). Downstream consumers must SKIP non-local
// slots rather than multiply by zero (NaN * 0 = NaN).
//
// Per-expert weight scale and bias are fused into the epilogue:
// y[t, slot, n] = acc * w_scale[lid] + bias[lid, n]
// which matches the dense path's GEMM -> moe_bias_add_3d sequence.
//
// Activation addressing (x_per_slot):
// gate_up: all slots of a token share x[token, :] (x_per_slot=0)
// down: each slot has its own activation row
// x[token * top_k + slot, :] (x_per_slot=1)
// ============================================================
#define SPARSE_TILE_N 8 // output columns per block (= warps per block)
// Weights FP8 E4M3 [local_experts, N, K], activations BF16 (W8A16).
// Decode is memory-bound (~2 FLOP/byte), so dequant-in-registers GEMV
// loses nothing to tensor cores and skips activation quantization.
__global__ void moe_sparse_gemv_fp8_bf16_kernel(
const __nv_bfloat16* __restrict__ x, // [T, K] or [T*top_k, K]
const __nv_fp8_e4m3* __restrict__ w, // [local_experts, N, K]
const float* __restrict__ w_scales, // [local_experts]
const __nv_bfloat16* __restrict__ bias, // [local_experts, N]
const int* __restrict__ topk_ids, // [T, top_k] global expert ids
__nv_bfloat16* __restrict__ y, // [T, top_k, N]
int N, int K, int top_k,
int expert_start, int local_experts,
int x_per_slot
) {
int token = blockIdx.z;
int slot = blockIdx.y;
int eid = topk_ids[token * top_k + slot];
int lid = eid - expert_start;
if (lid < 0 || lid >= local_experts) return; // block-uniform: safe
extern __shared__ float xs[]; // [K] activation row as float
const __nv_bfloat16* xrow =
x + (long long)(x_per_slot ? token * top_k + slot : token) * K;
for (int i = threadIdx.x; i < K; i += blockDim.x) {
xs[i] = __bfloat162float(xrow[i]);
}
__syncthreads();
int n = blockIdx.x * SPARSE_TILE_N + (threadIdx.x >> 5);
if (n >= N) return; // after __syncthreads: safe
int lane = threadIdx.x & 31;
// One warp per output column; uint4 = 16 FP8 weights per lane, the
// warp covers 512 contiguous bytes per iteration (coalesced).
const uint8_t* wrow = (const uint8_t*)w + ((long long)lid * N + n) * K;
float acc = 0.0f;
for (int i = lane; i < (K >> 4); i += 32) {
uint4 packed = *(const uint4*)(wrow + (long long)i * 16);
const __nv_fp8_e4m3* pw = (const __nv_fp8_e4m3*)&packed;
const float* xk = xs + i * 16;
#pragma unroll
for (int j = 0; j < 16; j++) {
acc += xk[j] * float(pw[j]);
}
}
#pragma unroll
for (int o = 16; o > 0; o >>= 1) {
acc += __shfl_down_sync(0xffffffffu, acc, o);
}
if (lane == 0) {
float v = acc * w_scales[lid]
+ __bfloat162float(bias[(long long)lid * N + n]);
y[((long long)token * top_k + slot) * N + n] = __float2bfloat16(v);
}
}
// MXFP4 W4A16 variant: packed E2M1 nibbles + per-32 UE8M0 block scale,
// same structure as batched_gemv_mxfp4_bf16_kernel but expert-indexed
// via topk_ids and with fused per-expert bias.
#define MXFP4_BLOCK 32
__device__ __constant__ float kSparseFp4Levels[8] =
{0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};
__device__ __forceinline__ float sparse_fp4_to_float(uint8_t code) {
float mag = kSparseFp4Levels[code & 0x7];
return (code & 0x8) ? -mag : mag;
}
__global__ void moe_sparse_gemv_mxfp4_bf16_kernel(
const __nv_bfloat16* __restrict__ x, // [T, K] or [T*top_k, K]
const uint8_t* __restrict__ w_packed, // [local_experts, N, K/2]
const uint8_t* __restrict__ w_scales, // [local_experts, N, K/32]
const __nv_bfloat16* __restrict__ bias, // [local_experts, N]
const int* __restrict__ topk_ids, // [T, top_k]
__nv_bfloat16* __restrict__ y, // [T, top_k, N]
int N, int K, int top_k,
int expert_start, int local_experts,
int x_per_slot
) {
int token = blockIdx.z;
int slot = blockIdx.y;
int eid = topk_ids[token * top_k + slot];
int lid = eid - expert_start;
if (lid < 0 || lid >= local_experts) return; // block-uniform: safe
extern __shared__ float xs[];
const __nv_bfloat16* xrow =
x + (long long)(x_per_slot ? token * top_k + slot : token) * K;
for (int i = threadIdx.x; i < K; i += blockDim.x) {
xs[i] = __bfloat162float(xrow[i]);
}
__syncthreads();
int n = blockIdx.x * SPARSE_TILE_N + (threadIdx.x >> 5);
if (n >= N) return;
int lane = threadIdx.x & 31;
int nblk = K / MXFP4_BLOCK;
const uint8_t* wp = w_packed + ((long long)lid * N + n) * (K >> 1);
const uint8_t* ws = w_scales + ((long long)lid * N + n) * nblk;
float acc = 0.0f;
for (int blk = lane; blk < nblk; blk += 32) {
float scale = exp2f((float)((int)ws[blk] - 127));
uint4 packed = *(const uint4*)(wp + (long long)blk * 16); // 32 nibbles
const uint8_t* pb = (const uint8_t*)&packed;
const float* xk = xs + blk * MXFP4_BLOCK;
#pragma unroll
for (int i = 0; i < 16; i++) {
uint8_t b = pb[i];
acc += xk[2 * i] * (sparse_fp4_to_float(b & 0xF) * scale);
acc += xk[2 * i + 1] * (sparse_fp4_to_float(b >> 4) * scale);
}
}
#pragma unroll
for (int o = 16; o > 0; o >>= 1) {
acc += __shfl_down_sync(0xffffffffu, acc, o);
}
if (lane == 0) {
float v = acc + __bfloat162float(bias[(long long)lid * N + n]);
y[((long long)token * top_k + slot) * N + n] = __float2bfloat16(v);
}
}
// Weighted sum over the slot axis: out[t, d] = sum over local slots of
// topk_weights[t, k] * down[t, k, d]. Non-local slots hold uninitialized
// memory and are SKIPPED (not multiplied by zero).
__global__ void moe_weighted_sum_sparse_bf16_kernel(
const __nv_bfloat16* __restrict__ down, // [T, top_k, hidden]
const int* __restrict__ topk_ids, // [T, top_k]
const float* __restrict__ topk_weights, // [T, top_k]
__nv_bfloat16* __restrict__ out, // [T, hidden]
int num_tokens, int hidden, int top_k,
int expert_start, int local_experts
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = num_tokens * hidden;
if (idx >= total) return;
int token = idx / hidden;
int dim = idx % hidden;
float sum = 0.0f;
for (int k = 0; k < top_k; k++) {
int lid = topk_ids[token * top_k + k] - expert_start;
if (lid >= 0 && lid < local_experts) {
float w = topk_weights[token * top_k + k];
float v = __bfloat162float(
down[((long long)token * top_k + k) * hidden + dim]);
sum += w * v;
}
}
out[idx] = __float2bfloat16(sum);
}
extern "C" {
void launch_moe_sparse_gemv_fp8_bf16(
const void* x, const void* w, const void* w_scales, const void* bias,
const void* topk_ids, void* y,
int num_tokens, int N, int K, int top_k,
int expert_start, int local_experts, int x_per_slot,
void* stream
) {
dim3 grid((N + SPARSE_TILE_N - 1) / SPARSE_TILE_N, top_k, num_tokens);
int block = SPARSE_TILE_N * 32;
size_t smem = (size_t)K * sizeof(float);
moe_sparse_gemv_fp8_bf16_kernel<<<grid, block, smem, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)x, (const __nv_fp8_e4m3*)w,
(const float*)w_scales, (const __nv_bfloat16*)bias,
(const int*)topk_ids, (__nv_bfloat16*)y,
N, K, top_k, expert_start, local_experts, x_per_slot
);
CUDA_CHECK_LAST_ERROR();
}
void launch_moe_sparse_gemv_mxfp4_bf16(
const void* x, const void* w_packed, const void* w_scales, const void* bias,
const void* topk_ids, void* y,
int num_tokens, int N, int K, int top_k,
int expert_start, int local_experts, int x_per_slot,
void* stream
) {
dim3 grid((N + SPARSE_TILE_N - 1) / SPARSE_TILE_N, top_k, num_tokens);
int block = SPARSE_TILE_N * 32;
size_t smem = (size_t)K * sizeof(float);
moe_sparse_gemv_mxfp4_bf16_kernel<<<grid, block, smem, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)x, (const uint8_t*)w_packed,
(const uint8_t*)w_scales, (const __nv_bfloat16*)bias,
(const int*)topk_ids, (__nv_bfloat16*)y,
N, K, top_k, expert_start, local_experts, x_per_slot
);
CUDA_CHECK_LAST_ERROR();
}
void launch_moe_weighted_sum_sparse_bf16(
const void* down, const void* topk_ids, const void* topk_weights,
void* out,
int num_tokens, int hidden, int top_k,
int expert_start, int local_experts,
void* stream
) {
int total = num_tokens * hidden;
int block = 256;
int grid = (total + block - 1) / block;
moe_weighted_sum_sparse_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)down,
(const int*)topk_ids, (const float*)topk_weights,
(__nv_bfloat16*)out,
num_tokens, hidden, top_k, expert_start, local_experts
);
CUDA_CHECK_LAST_ERROR();
}
}

View File

@@ -1748,6 +1748,27 @@ Text → Tokenizer → Text Tokens ────────────→
---
## 实际进展记录(与原计划的分叉,2026-06 更新)
Phase 017 按计划完成。Phase 18 起实际路线偏离了上面的原计划
(speculative decoding 与多模态推迟),实际走向是 MoE + 量化 + 稀疏化:
| 实际 Phase | 内容 | 文档 |
|---|---|---|
| 18 | Pipeline Parallelism(PP=2/4) | `18-pipeline-parallelism.md`、`benchmarks/pp-sweep.md` |
| 19 | **gpt-oss-20b MoE**:harmony 格式、attention sinks + sliding window、YaRN;两个 CUDA bug 实战(prefill sinks NaN、GEMV 未初始化 smem);GSM8K 94.5% 对齐 llama.cpp;FP8 W8A8 / MXFP4 W4A16 量化 | `19-gpt-oss-moe.md`、`benchmarks/{fp8-quantization,mxfp4-and-llama-decode}.md` |
| 20 | **稀疏 top-k MoE decode**:只算被路由的专家,decode 13.9→7.0ms,TP=2 下 decode/TTFT 全面快于 llama.cpp 同配置;gpt-oss 单卡 serving | `20-sparse-moe.md`、`benchmarks/sparse-moe.md` |
| 21 | **decode CUDA Graph + GPU argmax**:整个 decode step 录成一个图回放(thread-local launch stream、retained-warmup 分配策略、NCCL capture);greedy 采样换 GPU argmax。TPOT 7.5→5.9ms(TP=1)/ 5.8ms(TP=2);TP=2 全面领先 llama(1.26-1.47×),TP=1 差距 2.5×→2.0× | `21-cuda-graph-decode.md` |
**下一步候选(按预期收益排序):**
| 候选 Phase | 内容 | 预期 |
|---|---|---|
| 22 | **非专家权重量化**:qkv/o + lm_head(1.16GB/token)仍是 BF16 | TPOT 再省 ~1.5ms |
| 23 | **稀疏 prefill**(按专家 permute + grouped GEMM) | 长 prompt TTFT 51-75 → ~30ms |
| 24 | server 侧 harmony channel 分离(`reasoning_content` 流式输出,对齐 llama-server 行为) | API 易用性 |
| — | Speculative Decoding、多模态(原 16/19) | 推迟 |
## 里程碑总结
| 里程碑 | Phase | 验收标准 |
@@ -1757,7 +1778,9 @@ Text → Tokenizer → Text Tokens ────────────→
| ③ E2E API | 13 | HTTP streaming API, Python OpenAI SDK 可调用, 10 并发正确 |
| ④ 性能达标 | 15 | throughput >= 50% vLLM, profiling 报告完成 |
| ⑤ 多卡推理 | 17 | TP=2/4 同组 GPU 推理正确, scaling benchmark 完成 |
| ⑥ 多模态 | 19 | 图片输入 → 文字回答, API 端到端 |
| ⑥ MoE 模型(实际) | 19 | gpt-oss-20b 端到端正确, GSM8K 与 llama.cpp 持平 ✅ |
| ⑦ 性能反超(实际) | 20 | 同配置 decode 快于 llama.cpp(TP=2 达成;单卡是 Phase 21+ 目标) ✅ |
| ⑧ 多模态 | 推迟 | 图片输入 → 文字回答, API 端到端 |
## 外部依赖清单

118
docs/19-gpt-oss-moe.md Normal file
View File

@@ -0,0 +1,118 @@
# Phase 19: gpt-oss-20b — MoE 模型支持与两次 CUDA 调试实战
> 目标:支持 OpenAI gpt-oss-20b(32 专家 top-4 MoE),GSM8K 精度对齐 llama.cpp,
> 并以此为载体做 FP8 / MXFP4 量化。本文档事后整理,重点放在**踩过的坑**:
> 两个教科书级的 CUDA bug 排查过程比结论本身更有学习价值。
>
> 后续:`docs/20-sparse-moe.md`(稀疏化),benchmark 数据见
> `docs/benchmarks/{fp8-quantization,mxfp4-and-llama-decode,sparse-moe}.md`。
## 1. 模型架构(与 Qwen3 的差异点)
gpt-oss-20b(`config.json`,已在 dash5 验证):
| 项 | 值 | 说明 |
|---|---|---|
| layers / hidden | 24 / 2880 | hidden **不是** 128 的倍数的来源(2880 = 22.5×128) |
| heads | 64 Q / 8 KV,head_dim **64** | head_dim ≠ hidden/heads(64×64=4096>2880),GQA n_rep=8 |
| MoE | 32 experts,top-4,expert inter 2880 | router 是普通 Linear [2880→32] + bias |
| attention | **交替 sliding(128)/full**,layer 0 是 sliding | 每层带 **attention sinks**(每 head 一个可学习标量) |
| RoPE | YaRN(theta 150000, factor 32, orig 4096) | attn_factor = 0.1·ln(32)+1 乘在 cos/sin 上 |
| 激活 | clamp 后的 GLU | gate=gu[::2], up=gu[1::2](**交错**), gate≤7, up∈[-7,7], glu=gate·σ(1.702·gate), h=(up+1)·glu |
| 词表 | 201088 | EOS 是**列表** [200002,199999,200012] = `<|return|>`/`<|endoftext|>`/`<|call|>` |
| 其它 | attention_bias=true | q/k/v/o 全部带 bias(Qwen3 没有) |
**Harmony 对话格式**:gpt-oss 不是普通 chat template,输出分 channel
(`analysis`=思维链,`final`=正式回答),控制 token `<|start|>/<|channel|>/<|message|>/<|end|>`
三个坑:(1) system 消息必须含 `Reasoning:` 等 canonical 行,缺了模型 OOD、
channel 选择不稳定;(2) repetition penalty 会惩罚必须重复出现的控制 token,
导致模型只输出 analysis 不出 final(MoE 默认关掉);(3) 服务端要用多 EOS 判停。
## 2. MoE 前向(dense 版,Phase 20 之前)
```text
router GEMV → topk_softmax(GPU)→ moe_replicate(复制到全部本地专家)
→ batched GEMM gate_up → bias → GLU → batched GEMM down → bias
→ weighted_sum(只取 top-4)→ all-reduce
```
要点:top-k 的专家编号始终留在 GPU(`topk_ids`),host 不同步;
dense 的代价(每 token 读全部专家权重)在 Phase 20 用 sparse GEMV 解决。
TP 用 **expert parallelism**:rank r 拥有专家 [r·E/world, (r+1)·E/world),
weighted_sum 里按 `expert_start + local_experts` 过滤非本地命中,
all-reduce 把各 rank 的部分和加起来——这要求"跳过"语义而不是"乘 0"。
## 3. CUDA 调试实战 ①:prefill NaN(flash-attention sinks)
**症状**:长 prompt(≳192 token)prefill 后输出全 NaN → argmax 落在
token 201087(`max_by` 平局取最后)或 token 0(`!`)。短 prompt 完全正常。
**定位手法**:给每个 stage 加 NaN 检查(环境变量开关,事后移除),
二分出第一个出 NaN 的位置:layer-0 的 `flash_attention_sinks` 输出,
而它的 q/k 输入是干净的 → bug 在 kernel 内部。
**根因**:causal 跳过逻辑只剔除"完全在未来"的 kv tile;一个完全滑出
sliding window(128)的**过去** tile 仍被处理,所有 key 都被 mask 成 -inf
`row_max = -inf` → online softmax 里 `expf(-inf-(-inf)) = NaN`,
下一个有效 tile 的修正项 `0·NaN` 把整行毒掉。
**修复**:`row_max == -INFINITY` 的 tile 直接跳过(贡献为零)。
**教训**:online softmax 的"空 tile"是边界条件标配;decode kernel 早就
防了这个(`local_max==-INFINITY` guard),prefill kernel 漏了——
**同一逻辑的两份实现要做同样的边界测试**。触发阈值 ~192 token 解释了
"短测试全过、长对话必炸"的诡异表象。
## 4. CUDA 调试实战 ②:decode 间歇性乱码(GEMV 未初始化共享内存)
**症状**:同一 prompt ~70% 的运行在第二轮对话或长生成中突然输出
`!!!!`/token 201087/NaN logits,**间歇性** → 不是确定性逻辑错误,
是竞态或未初始化读。只有 gpt-oss 出问题,Qwen3 从不复现。
**定位**:逐 stage 检查,第一个出问题的是 decode 的 o_proj 输出
(maxabs≈1e33),输入干净 → M=1 的 GEMV kernel。
**根因**(`gemv.cu`):
```cuda
if (col >= N) return; // ← 在协作加载 x_shared 和 __syncthreads 之前!
...cooperative load + __syncthreads()...
```
`N % 128 != 0` 时,最后一个 block 的越界线程提前退出,**没参与**
共享内存装载;在界线程读到未初始化的 smem(且 `__syncthreads` 在有线程
已退出时是 UB)。命中条件:n=2880 的矩阵(o_proj、MoE gate_up/down)——
2880 % 128 ≠ 0;而 Qwen3 所有维度都是 128 对齐的,**所以"只有 gpt-oss
不稳定"**。q/k/v(4096)、lm_head(201088)对齐,幸免。
**修复**:所有线程先完成装载 + barrier,`col >= N` 检查移到 syncthreads
**之后**
**教训**:`__syncthreads()` 之前的任何 early-return 必须是 **block-uniform**
的。Phase 20 的 sparse GEMV 专门遵守了这条(整个 block 基于同一个
`topk_ids` 值统一退出,发生在装载之前)。
**修复后的验证**:GSM8K 全量 1319 题,xserv 94.5% vs llama.cpp 94.4%
——统计上同一水平,证明两个 kernel bug 就是之前 55% vs 95% 差距的全部原因。
## 5. 量化(详见 benchmark 文档)
- **FP8 W8A8**(`tools/quantize_fp8.py`):per-expert 标量 scale,权重转置
存 [E,N,K] 喂 cuBLASLt(Blackwell 要求 transA=T)。两个性能坑:
(1) 每次调用重建 plan + 跑 heuristic → 比 BF16 还慢,修复 = per-shape
plan cache;(2) 逐专家发射 ~768 个小 GEMM,修复 = 单条 strided-batched
调用 + 把 scale 移到融合的 post-scale kernel。最终 1.41× vs BF16。
- **MXFP4 W4A16**(`tools/quantize_mxfp4.py`):E2M1 + per-32 UE8M0 块 scale,
13GB 模型,贪心输出与 BF16 逐字一致,但手写 dequant-GEMV 打不过
cuBLASLt FP8(带宽效率差),定位为省显存方案。
- 检测方式:safetensors 的 dtype/scale 秩自动识别,loader 无需配置。
## 6. 本阶段的工具沉淀
- `bench-gpt-oss`:in-process 推理 + `--forced`(teacher-forced prefill
top-1)/`--forced-decode`(沿参考轨迹逐位置 top-1)——分离"前向算错"
和"贪心轨迹分叉"的利器。
- `tools/eval_gsm8k_fast.py`(持久 xserv-chat 管道)、
`tools/xserv_vs_llama.py`(warm-server 同机对打,计入 llama 的
reasoning_content)。
- 经验:**贪心解码不是逐位可复现的**(cuBLAS 非确定性会翻转后段 argmax),
多卡正确性要用"单卡×2 + 多卡×2 互相比",精度要用基准集而不是逐字 diff。

160
docs/20-sparse-moe.md Normal file
View File

@@ -0,0 +1,160 @@
# Phase 20: Sparse MoE Decode — 只算被路由到的专家
> 目标:消除 dense MoE 的无效权重读取,decode TPOT 追上并超过 llama.cpp。
> 前置:Phase 19(gpt-oss MoE 正确性)、FP8 W8A8 / MXFP4 W4A16 量化
> (见 `docs/benchmarks/fp8-quantization.md`、`docs/benchmarks/mxfp4-and-llama-decode.md`)。
## 1. 现状:dense MoE 在浪费什么
gpt-oss-20b 是 32 专家 top-4 的 MoE:router 给每个 token 选 4 个专家,
理论上每 token 只需要读 4/32 = 12.5% 的专家权重。但 `moe_forward`
(`crates/xserv-model/src/gpt_oss.rs`)目前是 **dense** 实现:
```text
1. router GEMV [T, 2880] → [T, 32]
2. topk_softmax (GPU) → topk_ids [T,4], topk_weights [T,4]
3. moe_replicate x 复制 16 份 → [16, T, 2880] ← 浪费开始
4. batched GEMM gate_up 全部 16 个本地专家都算 ← 读 16 份权重
5. bias + GLU
6. batched GEMM down 全部 16 个本地专家都算 ← 读 16 份权重
7. bias
8. moe_weighted_sum 只挑出 top-4 加权求和,其余 12 个全部丢弃
9. all-reduce
```
为什么当初这么写:batched GEMM(cuBLAS strided-batched)要求规则的
`[E, T, K]` 形状;top-4 的专家编号在 **GPU** 上(`topk_ids`),host 不知道
该挑哪几个,挑了形状也不规则。dense 是"先把正确性做出来"的合理起点,
但每 token 把 16 个专家的权重从 HBM 全部读一遍。
### 字节账本(decode,每 token,TP=2 每卡 16 个本地专家)
每层每专家:gate_up `[2880, 5760]` + down `[2880, 2880]` ≈ 24.9 M 参数。
| 方案 | 每卡每 token 专家字节 | 相对量 |
|---|---|---|
| xserv dense FP8(现状) | 16 × 24.9 MB × 24 层 ≈ **9.6 GB** | 1× |
| xserv sparse FP8(本阶段) | ~2 × 24.9 MB × 24 层 ≈ **1.2 GB** | 1/8 |
| llama.cpp sparse MXFP4 | ~2 × 12.5 MB × 24 层 ≈ **0.6 GB** | 1/16 |
(top-4 均匀散落在 2 张卡上,期望每卡 2 个命中;严格说每层取的是
两卡命中数的 max,期望 ≈ 2.6,仍是 ~6-8× 的节省。)
实测旁证:FP8 dense TP=2 TPOT 13.1 ms,其中专家 GEMM ≈ 9.6 GB ÷ ~1 TB/s
≈ 9.5 ms,其余(attention、qkv/o、lm_head、48 次 PCIe all-reduce)≈ 3.5 ms。
**专家权重读取占 TPOT 的 ~3/4,这就是与 llama.cpp(6.6 ms)的全部差距。**
## 2. Roofline:M=1 时为什么"省字节 = 省时间"
decode 的 GEMV(M=1)每读 1 字节 FP8 权重只做 2 FLOP(乘加)。
RTX 5090:HBM ~1.8 TB/s,BF16 算力 ~210 TFLOPS —— 算强比(arithmetic
intensity)需要 ~100 FLOP/byte 才能喂饱算力,GEMV 只有 2。结论:
1. **decode 完全 memory-bound**,tensor core 帮不上忙 → 手写 W8A16 GEMV
(权重 FP8、激活保持 BF16)不会输给 cuBLASLt 的 W8A8 tensor-core GEMM,
还省掉激活量化 kernel,精度更好(激活不再有量化误差)。
2. 优化只有一个方向:**少读字节**。sparse(×8)与 4-bit(×2)正交,
可叠加。本阶段先做 sparse,FP8 与 MXFP4 两种权重格式都支持。
## 3. Sparse 设计:让 kernel 自己按 topk_ids 索引权重
关键观察:`topk_ids` 本来就在 GPU 上。不需要 host 知道选了谁 ——
**让 GEMV kernel 的每个 block 自己读 `topk_ids[token, slot]`,
直接寻址到对应专家的权重**,不命中本卡就整块退出。零 host 同步,
管线保持完全异步(这是之前排查过的:decode 循环无 per-layer sync)。
新数据流(`num_tokens ≤ 8` 时启用):
```text
x [T, 2880]
├─ router → topk_ids/weights [T, 4] (不变)
├─ sparse GEMV gate_up → [T, 4, 5760] bias 已融合,非本地 slot 不写
├─ GLU → [T*4, 2880]
├─ sparse GEMV down → [T, 4, 2880] bias 已融合,非本地 slot 不写
└─ weighted_sum_sparse → [T, 2880] 只累加本地 slot
all-reduce (不变)
```
`moe_replicate` 和独立的 bias kernel 在 sparse 路径下消失;FP8 路径还省掉
`quantize_bf16_to_fp8_rowwise`
### Kernel 设计(`csrc/moe/moe_sparse.cu`)
`moe_sparse_gemv_{fp8,mxfp4}_bf16_kernel`:
- **grid = (N/8, top_k, tokens)**,block = 8 warp × 32 lane。
每个 block 负责一个 (token, slot) 的 8 个输出列,**一个 warp 算一个输出**。
- block 先读 `eid = topk_ids[token*top_k + slot]`,折算 `lid = eid - expert_start`;
不在 `[0, local_experts)` 就整块 return。
- 命中的 block 把激活行(K=2880 个 BF16 → float)协作搬进 shared memory
(11.25 KB),`__syncthreads()`,然后每 warp 沿 K 维做点积:
每 lane 一次 `uint4` 读 16 字节权重(FP8 = 16 个权重,MXFP4 = 32 个 nibble),
warp 内 32 lane 连续 → 512B coalesced 事务。
- epilogue(lane 0):`y = acc * w_scale[lid] + bias[lid, n]` —— per-expert
scale 和 bias 都融合在这里,与 dense 路径的"GEMM → bias add → 路由加权"
语义逐位等价(HF 参考实现也是先加 bias 再乘路由权重)。
- gate_up 与 down 共用同一个 kernel,用 `x_per_slot` 区分激活寻址:
gate_up 时 4 个 slot 共享 `x[token]`;down 时各读自己的 `act[token*4+slot]`
### 两个容易写错的安全点
1. **early-return 必须 block-uniform。** Phase 19 的 GEMV 垃圾输出 bug
(commit `3b9e32e`)正是"部分线程在 `__syncthreads()` 之前 return"导致
读未初始化 shared memory。这里的 return 发生在 smem 装载**之前**,且整个
block 基于同一个 `topk_ids` 值统一退出 —— 没有 divergence,合法且安全。
2. **weighted-sum 对非本地 slot 必须"跳过",不能"乘 0"。** 非本地 slot 的
GEMV 输出从未被写入(未初始化显存,可能是 NaN 位型),GLU 也会在上面算出
垃圾。`NaN × 0 = NaN`,所以求和 kernel 用 `if (local) sum += w*v` 跳过,
垃圾永远不进入数据流(dense 路径的 `moe_weighted_sum` 同理)。
## 4. 为什么 prefill 保持 dense
dense batched GEMM 把 16 份权重读**一次**,服务全部 M 个 token;
sparse GEMV 是**每 token** 重读自己的 ~2 份。字节交叉点:
```text
sparse 读 M × 2 份 vs dense 读 16 份 → M ≈ 8 (TP=2)
```
M > 8 后 dense 更省(且 GEMM 是 compute-bound,tensor core 开始有用)。
所以 sparse 只在 `num_tokens ≤ 8` 启用 —— 覆盖 decode(连续批合并的
多请求 decode 也是小 M)和极短的 re-prefill。真正的 sparse prefill
(按专家对 token 做 permute/gather 的 grouped GEMM,vLLM 的做法)是
后续阶段,主要收益在长 prompt TTFT。
## 5. 实测结果(2026-06-12,完整数据见 `docs/benchmarks/sparse-moe.md`)
In-process decode(bench-gpt-oss,greedy 96 tok):
| | TPOT | tok/s |
|---|---|---|
| dense FP8 TP=2(基线) | 13.9 ms | 72 |
| **sparse FP8 TP=2** | **7.6 ms(1.8×)** | **132** |
| sparse MXFP4 TP=2 | 8.4 ms | 118 |
| sparse FP8 TP=1(单卡) | 7.8 ms | 128 |
Warm-server 对打 llama.cpp(`tools/xserv_vs_llama.py`):
- **TP=2 vs TP=2:xserv 首次全面反超** —— TPOT 7.19-7.32 ms vs llama
7.54-8.42 ms;短/中 prompt TTFT 也领先(35/49 vs 63/65 ms)。
- **TP=1 vs TP=1:llama 大胜**(2.88-3.22 ms vs 7.0-7.2 ms,347 vs 140
tok/s)。单卡才是 llama 的最优配置:它的跨卡 split 在 PCIe 上每 token
损失 ~5 ms,而单卡时它"全模型 4-bit + CUDA graph 整 token 回放"的
优势全部兑现。xserv 的残余 ~7 ms ≈ ~3 ms HBM(其中非专家权重还是
BF16,含 1.16 GB 的 lm_head)+ ~4 ms 启动开销(~200 个 kernel
launch/token,无 CUDA graph)。
- **正确性:GSM8K-100 = 96%**(dense FP8 91% / BF16 90%,greedy 噪声内,
无回归)。
教训:之前"CUDA graph ≈ 无用(~0.5-1.5ms)"的结论是相对 13 ms 的
dense TPOT 而言;专家成本砍掉后,launch 开销变成了最大的单项。
## 6. 下一阶段(按收益排序)
1. **decode CUDA graph**(~2-4 ms):当前最大单项。
2. **非专家权重量化**(~1-1.5 ms):qkv/o + lm_head 仍是 BF16,每 token
白读 ~2.3 GB;llama 是全模型 4-bit。
3. **sparse prefill**(grouped GEMM):长 prompt TTFT 94-120 ms → llama
的 ~30 ms 量级。
4. **W4A4 FP4 tensor core / 带宽调优的 MXFP4 GEMV**:让 4-bit 专家真正
快过 FP8(目前 8.4 vs 7.6 ms,GEMV 效率抵消了字节优势)。

View File

@@ -0,0 +1,111 @@
# Phase 21: gpt-oss decode CUDA Graph + GPU argmax
> 目标:消除 decode 的每 token 固定开销。Phase 20 之后 TPOT ~7ms,其中
> GPU 实际计算只占一部分,剩下是 ~200 个 kernel launch 和 per-token 的
> host 工作。本阶段把**整个 decode step 捕获成一个 CUDA graph**,每 token
> 一次 `cudaGraphLaunch` 回放;顺带把 greedy 采样换成 GPU argmax。
>
> 实现:`crates/xserv-model/src/gpt_oss_graph.rs`(~150 行)+ 三块基础设施。
## 1. CUDA Graph 是什么,为什么有约束
`cudaStreamBeginCapture` 之后,发到该 stream 的 kernel 不执行而是被**录制**;
`EndCapture + Instantiate` 得到可执行图;以后每步 `cudaGraphLaunch` 一次性
重放全部 ~200 个 kernel,host 端开销从 ~200 次 launch 降到 1 次。
代价是三条硬约束,每条都对应一个工程问题:
1. **地址稳定**:录制时烤进图里的全部指针,回放时必须仍然有效且指向正确数据;
2. **capture 期间禁止"不安全"调用**:`cudaMalloc`/同步 memcpy/`cudaDeviceSynchronize`
都会让 capture 报错(error 900);
3. **形状固定**:grid 尺寸被烤死,变 shape 就要重录。
## 2. 为什么 xserv 的 decode 本来就"差一点"就能整图捕获
逐项检查 decode step 的输入,发现绝大部分已经满足地址稳定:
| 每步会变的输入 | 地址 | 内容如何更新 |
|---|---|---|
| block table / context lens | PagedKVCache 的常驻 GPU 缓冲 ✓ | `decode_prepare` 在图外 H2D |
| KV 写入位置 | scatter kernel **从 GPU 上的 context_lens 读** ✓ | 同上 |
| attention 读取范围 | paged kernel 从同一缓冲读 ✓ | 同上 |
| MoE 专家选择 | sparse GEMV 从图内刚写的 `topk_ids` 读 ✓ | 数据依赖,天然支持 |
| token id / position | ✗ 原来是每步从 host slice 上传 | **本阶段改造点** |
也就是说,Phase 11(paged KV)和 Phase 20(sparse MoE)的"数据驱动"设计
无意中已经为 graph 化铺平了路 —— 唯二需要动的是 embedding 的 token id 和
RoPE 的 position:各加一个 device-buffer 变体(`embedding_device_ids` /
`rope_inplace_device_pos`),id/pos 存进两个常驻 4 字节缓冲,每步图外更新。
重构后的结构:
```text
forward_decode_paged = decode_prepare(host 簿记,图外)
+ upload ids/pos(图外)
+ decode_core(纯 GPU,可整段捕获)
+ advance_seq_len(host 簿记,图外)
```
## 3. 三个工程问题
### 3.1 null stream 不可捕获 → thread-local launch stream
全代码库的 kernel 都发射在 legacy null stream 上,而 capture 必须在显式
stream 上。解法:`xserv_cuda::stream` 加一个 **thread-local 当前 stream**
(默认 null,行为与从前逐字节一致),所有 kernel wrapper、cuBLAS 的
`cublasSetStream`、NCCL 的 collective 全部改读它。capture 代码用 RAII guard
(`push_stream`)把 capture stream 装进去,录完自动还原。
顺序正确性:显式 stream 以默认(blocking)方式创建,legacy stream 与其
双向隐式同步,所以图外的 H2D/采样 memcpy 与回放天然有序。
### 3.2 capture 期间禁止 cudaMalloc → "retained warmup" 二段式
中间张量来自 caching allocator;capture 中任何一次 pool miss 都会触发
`cudaMalloc` → error 900。第一版实现就栽在这里:**隔离机制自己制造了
pool miss**(capture 中释放的块被隔离,下一层同尺寸分配就找不到块了)。
解法是把同一个 step 先 eager 跑一遍、但**隔离打开**(`begin_retain`):
释放的块全部扣下不回池 → 跑完后池外恰好积累了"这一步需要的每一块";
把它们整批放回池,再开始 capture —— capture 重复完全相同的分配序列,
每次分配都命中池,一次 cudaMalloc 都不会发生。
(重复执行同一 step 是无害的:KV scatter 往同一个位置重写同样的值。)
### 3.3 回放引用的内存不能被别人拿走 → 隔离仓(quarantine)
capture 录下的中间缓冲在 host 侧早就 Drop 了,但图每次回放都会读写这些
地址。若它们回到分配池、被后续 prefill 拿走长期持有,就是双写损坏。
所以 capture 期间释放的块进入 `RetainedBlocks` 隔离仓,由 graph 对象持有,
graph 销毁时才归还 —— 这些内存在 graph 存活期内被锁定为它专用。
### 3.4 两个顺手的点
- **THREAD_LOCAL capture mode**:GLOBAL 模式下,任何线程的 cudaMalloc 都会
毒化 capture;TP 多 rank 线程并发 capture 必须用 THREAD_LOCAL。
- **NCCL 可以被捕获**:rank 内 `ncclAllReduce` 发在 capture stream 上即可,
TP=2 一次成功(各 rank 录各自的图,回放时 collective 自然配对)。
## 4. 意外的教训:launch 开销没有想象的大,argmax 才是大头
A/B 实测(in-process,FP8,96 tok):
| | TP=1 | TP=2 |
|---|---|---|
| eager + host argmax(Phase 20 末) | 7.5 ms | 7.6 ms |
| graph + host argmax | 6.9 ms | 6.9 ms |
| eager + **GPU argmax** | 6.5 ms | — |
| **graph + GPU argmax** | **5.9 ms** | **5.8 ms** |
- **graph 只省了 ~0.6ms**:decode 循环本来就是全异步的,launch 大部分被
GPU 执行掩盖,"~200 launch ≈ 4ms"的预估错了 —— 优化要测不要猜。
- **GPU argmax 省了 ~1ms**:greedy 采样原来每 token 把 [1, 201088] 的
logits(402KB)同步拷回 host、再扫描 201K 个 bf16。仓库里 Phase 15 就写好
的 argmax kernel(kernel 内归约 + 4 字节 D2H)一直没接到 `sample()` 上。
- 细节:GPU argmax 与 host `max_by` 对**完全相等**的 logits 平局取的索引
不同,greedy 轨迹会在某个平局 token 处分叉 —— 输出同样合法(GSM8K 验证)。
## 5. 结果与剩余瓶颈
`docs/benchmarks/sparse-moe.md` 的 Phase 21 小节(warm-server 对打 llama
的数字以那里为准)。剩余 TPOT 的构成:~3ms 是 HBM 字节(其中非专家权重
仍是 BF16,含 1.16GB 的 lm_head —— **Phase 22 量化它们**),其余是 GEMV
带宽效率与 attention。llama 单卡 2.9ms 的差距主要就在"全模型 4-bit"。

View File

@@ -0,0 +1,111 @@
# Sparse MoE decode — 1.8× over dense; beats llama.cpp at TP=2 (gpt-oss-20b, RTX 5090)
Phase 20 (`docs/20-sparse-moe.md`): decode computes only the routed top-4
experts via fused expert-indexed GEMVs (`csrc/moe/moe_sparse.cu`) instead of
the dense all-local-expert batched GEMM. FP8 weights run W8A16 (weights FP8,
activations BF16 — decode is memory-bound, tensor cores irrelevant at M=1);
MXFP4 runs W4A16. Dense path retained for prefill / `num_tokens > 8` and via
`XSERV_DENSE_MOE=1` for A/B.
## In-process decode (bench-gpt-oss, greedy, 96 tokens)
| config | TPOT | tok/s |
|---|---|---|
| dense FP8 TP=2 (baseline) | 13.9 ms | 72 |
| **sparse FP8 TP=2** | **7.6 ms** | **132** |
| sparse MXFP4 TP=2 | 8.4 ms | 118 |
| sparse FP8 TP=1 (one 5090) | 7.8 ms | 128 |
| sparse MXFP4 TP=1 | 8.9 ms | 113 |
- Sparse FP8 = **1.8× over dense**. Greedy output stays coherent.
- TP=1 ≈ TP=2: expert reads are now so small that PCIe all-reduce eats the
TP gain — single-GPU serving becomes the attractive deployment.
- MXFP4 reads half the bytes of FP8 but stays slower: the 4-bit dequant GEMV
has lower effective bandwidth (same fixed inefficiency seen in the dense
MXFP4 experiments); at sparse sizes both are partly launch/latency-bound.
## Head-to-head vs llama.cpp (tools/xserv_vs_llama.py, warm servers, TP=2, GPUs 0-1, 6 reps, 256 tok)
| prompt | metric | xserv sparse FP8 | llama MXFP4 | xserv vs llama |
|---|---|---|---|---|
| short | TTFT | **35.3 ms** | 62.7 ms | 1.78× faster |
| short | TPOT | **7.32 ms** | 8.42 ms | 1.15× faster |
| medium | TTFT | **49.4 ms** | 65.0 ms | 1.32× faster |
| medium | TPOT | **7.19 ms** | 7.54 ms | 1.05× faster |
| medium | tok/s | **139.1** | 132.7 | |
| long (1.6k) | TTFT | 94.1 ms | **44.7 ms** | 0.48× (llama wins) |
| long | TPOT | **7.25 ms** | 7.64 ms | 1.05× faster |
**Decode TPOT now beats llama.cpp at every prompt length** (was 2× slower:
13.1 vs 6.6 ms before sparse). Remaining loss: long-prompt TTFT — prefill is
still the dense all-expert GEMM; sparse/grouped prefill is the next phase.
**Post-review fixes** (same harness, rerun): removing three leftover
`cudaDeviceSynchronize` from the decode hot path and replacing the CPU-tiled
prefill bias-add (96 D2H/H2D round-trips per prefill) with a GPU broadcast
kernel improved both axes — TPOT 7.19-7.32 → **6.99-7.21 ms**, TTFT
short/medium/long 35/49/94 → **29/42/79 ms**. GSM8K-50: 94% (unchanged).
## TP=1 head-to-head (single 5090; server now routes gpt-oss tp=1 to the TP engine)
| prompt | metric | xserv sparse FP8 | llama MXFP4 |
|---|---|---|---|
| short | TTFT / TPOT | 42.8 ms / 7.00 ms | **34.5 ms / 3.22 ms** |
| medium | TTFT / TPOT | 57.1 ms / 7.19 ms | **37.3 ms / 2.89 ms** |
| long | TTFT / TPOT | 119.6 ms / 7.20 ms | **27.8 ms / 2.88 ms** |
| | tok/s | 139143 | **311347** |
**Single-GPU is llama.cpp's sweet spot and it wins 2.22.5×.** Two structural
reasons, both instructive:
1. llama TP=2 (7.58.4 ms) is much WORSE than its TP=1 (2.9 ms): its PCIe
cross-GPU split costs ~5 ms/token. xserv's NCCL all-reduce is cheap enough
that TP=2 ≈ TP=1 (7.2 vs 7.0 ms) — but xserv's single-GPU floor is high.
2. xserv TP=1 reads ~4.7 GB/token (experts FP8 2.4 GB + **non-expert weights
still BF16** ~2.3 GB, half of that the 201k-vocab lm_head) ≈ 3.1 ms of pure
HBM time; the other ~4 ms is launch overhead (~200 kernels/token, no CUDA
graphs) + BF16 GEMV efficiency. llama reads ~1.3 GB (everything MXFP4) and
replays the whole token as one CUDA graph.
## Correctness
- Greedy generations coherent across prompts (FP8/MXFP4, TP=1/2).
- Sparse FP8 is W8A16 vs dense W8A8 — activations are no longer quantized, so
tokens are not expected to be byte-identical to dense; quality is checked by
GSM8K instead.
- **GSM8K-100 (greedy, TP=2, `tools/eval_gsm8k_fast.py`): 96/100 = 96.0%** vs
dense FP8 91.0% / BF16 90.0% — no regression (within greedy-nondeterminism
noise; W8A16 removes activation-quantization error so ≥ dense is expected).
Avg 1.3 s/problem also reflects the decode speedup.
## Phase 21 update: decode CUDA graph + GPU argmax (docs/21-cuda-graph-decode.md)
The whole batch=1 decode step now replays as one CUDA graph, and greedy
sampling uses the GPU argmax kernel (4-byte D2H instead of a 402 KB logits
copy + 201k-element host scan). In-process A/B: graph 0.6 ms, GPU argmax
1.0 ms. Warm-server head-to-head (same harness/GPUs, 6 reps):
| | xserv FP8 (graph) | llama MXFP4 | |
|---|---|---|---|
| TP=2 TPOT | **5.765.89 ms** (170174 tok/s) | 7.428.45 ms | **xserv 1.261.47×** |
| TP=2 TTFT s/m/l | **25 / 28 / 51 ms** | 63 / 66 / 45 ms | xserv 2.4× s/m; long ~par |
| TP=1 TPOT | 5.785.95 ms | **2.803.22 ms** | llama 2.0× (was 2.5×) |
| TP=1 TTFT s/m | **32 / 35 ms** | 34 / 36 ms | xserv slightly ahead |
GSM8K-50 through the graph path: 47/50 = 94% (unchanged). Note: GPU argmax
breaks exact-tie logits differently than the host scan, so greedy trajectories
can legitimately diverge at a tie token.
## Remaining gaps / next levers (to catch llama TP=1 at 2.8 ms)
Per-token fixed overhead is now mostly gone; the residual ~5.8 ms is
dominated by HBM bytes and kernel efficiency. In impact order:
1. **Quantize non-expert weights** (~1.5 ms): attn qkv/o + the 1.16 GB BF16
lm_head read every token; FP8/MXFP4 them like llama quantizes everything.
2. **GEMV/attention bandwidth tuning**: effective BW of the hand GEMVs is
well under peak; llama's 2.8 ms implies ~85%+ efficiency on ~1.3 GB.
3. **Sparse prefill** (permute tokens by expert + grouped GEMM): long-prompt
TTFT 5175 ms → llama's ~30 ms territory.
4. **W4A4 FP4 tensor cores / bandwidth-tuned MXFP4 GEMV**: make 4-bit experts
actually beat FP8.

View File

@@ -81,7 +81,8 @@ def main():
parser.add_argument("--max-tokens", type=int, default=512, help="Max generation tokens")
parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism")
parser.add_argument("--offset", type=int, default=0, help="Start from problem N")
parser.add_argument("--gpu", type=int, default=0, help="GPU device index")
parser.add_argument("--gpu", type=str, default="0",
help="CUDA_VISIBLE_DEVICES value, e.g. '0' or '2,3' (must cover --tp ranks)")
args = parser.parse_args()
if not DATA_PATH.exists():