style: format Rust workspace

This commit is contained in:
2026-06-18 18:11:58 +08:00
parent 013465fc06
commit 531cd3fe08
57 changed files with 4045 additions and 1204 deletions

View File

@@ -165,7 +165,10 @@ pub fn begin_retain() {
/// Stop quarantining and hand the quarantined blocks to the caller.
pub fn end_retain() -> RetainedBlocks {
RETAINED.with(|cell| {
let list = cell.borrow_mut().take().expect("end_retain without begin_retain");
let list = cell
.borrow_mut()
.take()
.expect("end_retain without begin_retain");
RetainedBlocks(list)
})
}

View File

@@ -48,9 +48,7 @@ pub fn device_info(device: u32) -> Result<DeviceInfo> {
// Heap-allocate oversized buffer for cudaDeviceProp (layout varies by CUDA version).
// CUDA 12.x struct is ~5-6 KB; use 32 KB to guard against future growth.
let mut prop_buf = vec![0u8; 32768];
error::check(unsafe {
ffi::cudaGetDeviceProperties(prop_buf.as_mut_ptr(), device as i32)
})?;
error::check(unsafe { ffi::cudaGetDeviceProperties(prop_buf.as_mut_ptr(), device as i32) })?;
// Name is always the first field: char[256].
let name = unsafe { CStr::from_ptr(prop_buf.as_ptr() as *const c_char) }
.to_string_lossy()

View File

@@ -64,11 +64,5 @@ unsafe extern "C" {
pub fn cudaGraphExecDestroy(graph_exec: CudaGraphExec) -> i32;
// --- Our test kernel ---
pub fn launch_vecadd_f32(
a: *const f32,
b: *const f32,
c: *mut f32,
n: i32,
stream: CudaStream,
);
pub fn launch_vecadd_f32(a: *const f32, b: *const f32, c: *mut f32, n: i32, stream: CudaStream);
}

View File

@@ -54,30 +54,21 @@ impl CudaGraph {
// made by THIS thread invalidate the capture. With GLOBAL mode, TP rank
// threads capturing concurrently would poison each other's captures.
error::check(unsafe {
ffi::cudaStreamBeginCapture(
stream.as_raw(),
ffi::CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL,
)
ffi::cudaStreamBeginCapture(stream.as_raw(), ffi::CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL)
})
}
/// End capture and instantiate the executable graph.
pub fn end_capture(&mut self, stream: &CudaStream) -> Result<()> {
error::check(unsafe {
ffi::cudaStreamEndCapture(stream.as_raw(), &mut self.graph)
})?;
error::check(unsafe {
ffi::cudaGraphInstantiate(&mut self.exec, self.graph, 0)
})
error::check(unsafe { ffi::cudaStreamEndCapture(stream.as_raw(), &mut self.graph) })?;
error::check(unsafe { ffi::cudaGraphInstantiate(&mut self.exec, self.graph, 0) })
}
/// Replay the captured graph on `stream`.
/// Panics if no graph has been captured yet.
pub fn launch(&self, stream: &CudaStream) -> Result<()> {
assert!(self.is_ready(), "CudaGraph::launch called before capture");
error::check(unsafe {
ffi::cudaGraphLaunch(self.exec, stream.as_raw())
})
error::check(unsafe { ffi::cudaGraphLaunch(self.exec, stream.as_raw()) })
}
fn destroy_inner(&mut self) {

View File

@@ -11,4 +11,4 @@ pub use device::DeviceInfo;
pub use error::{CudaError, Result};
pub use graph::CudaGraph;
pub use memory::{GpuBuffer, PinnedBuffer};
pub use stream::{current_stream_raw, push_stream, CudaStream, StreamGuard};
pub use stream::{CudaStream, StreamGuard, current_stream_raw, push_stream};

View File

@@ -22,7 +22,12 @@ impl GpuBuffer {
assert!(len > 0, "cannot allocate 0 bytes on GPU");
let mut ptr = std::ptr::null_mut();
error::check(unsafe { ffi::cudaMalloc(&mut ptr, len) })?;
Ok(Self { ptr, len, owned: true, pooled: false })
Ok(Self {
ptr,
len,
owned: true,
pooled: false,
})
}
/// Mark this buffer as pooled (returned to caching allocator on drop)
@@ -92,9 +97,7 @@ impl GpuBuffer {
/// Copy from another GPU buffer (D2D).
pub fn copy_from_device(&mut self, src: &GpuBuffer) -> Result<()> {
let n = src.len.min(self.len);
error::check(unsafe {
ffi::cudaMemcpy(self.ptr, src.ptr, n, ffi::CUDA_MEMCPY_D2D)
})
error::check(unsafe { ffi::cudaMemcpy(self.ptr, src.ptr, n, ffi::CUDA_MEMCPY_D2D) })
}
/// Fill buffer with zeros.
@@ -103,7 +106,13 @@ impl GpuBuffer {
}
/// Copy `count` bytes from `src` buffer at `src_offset` to this buffer at `dst_offset`.
pub fn copy_from_device_at(&mut self, src: &GpuBuffer, src_offset: usize, dst_offset: usize, count: usize) -> Result<()> {
pub fn copy_from_device_at(
&mut self,
src: &GpuBuffer,
src_offset: usize,
dst_offset: usize,
count: usize,
) -> Result<()> {
assert!(src_offset + count <= src.len);
assert!(dst_offset + count <= self.len);
error::check(unsafe {
@@ -117,7 +126,14 @@ impl GpuBuffer {
}
/// Async copy `count` bytes from `src` at `src_offset` to `self` at `dst_offset` on `stream`.
pub fn copy_from_device_at_async(&mut self, src: &GpuBuffer, src_offset: usize, dst_offset: usize, count: usize, stream: &CudaStream) -> Result<()> {
pub fn copy_from_device_at_async(
&mut self,
src: &GpuBuffer,
src_offset: usize,
dst_offset: usize,
count: usize,
stream: &CudaStream,
) -> Result<()> {
assert!(src_offset + count <= src.len);
assert!(dst_offset + count <= self.len);
error::check(unsafe {
@@ -161,9 +177,7 @@ impl GpuBuffer {
/// Async zero fill on stream.
pub fn zero_async(&mut self, stream: &CudaStream) -> Result<()> {
error::check(unsafe {
ffi::cudaMemsetAsync(self.ptr, 0, self.len, stream.as_raw())
})
error::check(unsafe { ffi::cudaMemsetAsync(self.ptr, 0, self.len, stream.as_raw()) })
}
/// Consume the buffer without freeing GPU memory. Returns the raw pointer and length.
@@ -178,7 +192,12 @@ impl GpuBuffer {
/// Reconstruct a GpuBuffer from a raw pointer + length.
/// Safety: ptr must have been allocated with cudaMalloc, len must be correct.
pub unsafe fn from_raw(ptr: *mut u8, len: usize) -> Self {
Self { ptr, len, owned: true, pooled: false }
Self {
ptr,
len,
owned: true,
pooled: false,
}
}
/// Create a non-owning view of GPU memory. Dropping this buffer does NOT
@@ -189,7 +208,12 @@ impl GpuBuffer {
/// `ptr` must point to a valid GPU allocation of at least `len` bytes that
/// will remain live for the lifetime of the returned `GpuBuffer`.
pub unsafe fn borrow_raw(ptr: *mut u8, len: usize) -> Self {
Self { ptr, len, owned: false, pooled: false }
Self {
ptr,
len,
owned: false,
pooled: false,
}
}
}

View File

@@ -14,7 +14,10 @@ fn test_device_info() {
info.compute_major, info.compute_minor
);
println!(" SM Count: {}", info.sm_count);
println!(" Shared Mem/Block: {} KB", info.shared_mem_per_block / 1024);
println!(
" Shared Mem/Block: {} KB",
info.shared_mem_per_block / 1024
);
println!(" Warp Size: {}", info.warp_size);
println!(" Max Threads/Block: {}", info.max_threads_per_block);
@@ -145,7 +148,11 @@ fn test_caching_allocator() {
// Second allocation of same size: should hit cache
let _buf2 = alloc.alloc(1024).unwrap();
assert_eq!(alloc.stats().cuda_malloc_count, 1, "should reuse cached buffer");
assert_eq!(
alloc.stats().cuda_malloc_count,
1,
"should reuse cached buffer"
);
assert_eq!(alloc.stats().cache_hit_count, 1);
}
@@ -198,11 +205,17 @@ fn test_async_copy() {
}
let mut gpu = GpuBuffer::alloc(4096).unwrap();
unsafe { gpu.copy_from_host_async(pinned.as_slice(), &stream).unwrap() };
unsafe {
gpu.copy_from_host_async(pinned.as_slice(), &stream)
.unwrap()
};
stream.synchronize().unwrap();
let mut out_pinned = PinnedBuffer::alloc(4096).unwrap();
unsafe { gpu.copy_to_host_async(out_pinned.as_mut_slice(), &stream).unwrap() };
unsafe {
gpu.copy_to_host_async(out_pinned.as_mut_slice(), &stream)
.unwrap()
};
stream.synchronize().unwrap();
assert_eq!(pinned.as_slice(), out_pinned.as_slice());

View File

@@ -34,7 +34,12 @@ pub const NCCL_SUCCESS: i32 = 0;
unsafe extern "C" {
pub fn ncclGetUniqueId(uid: *mut NcclUniqueId) -> i32;
// ncclUniqueId is passed BY VALUE (a 128-byte struct) per the NCCL ABI.
pub fn ncclCommInitRank(comm: *mut NcclComm, nranks: i32, commid: NcclUniqueId, rank: i32) -> i32;
pub fn ncclCommInitRank(
comm: *mut NcclComm,
nranks: i32,
commid: NcclUniqueId,
rank: i32,
) -> i32;
pub fn ncclCommDestroy(comm: NcclComm) -> i32;
pub fn ncclAllReduce(
sendbuff: *const c_void,
@@ -78,5 +83,10 @@ pub fn err_string(result: i32) -> String {
}
pub fn check(result: i32, what: &str) {
assert_eq!(result, NCCL_SUCCESS, "{what} failed: {}", err_string(result));
assert_eq!(
result,
NCCL_SUCCESS,
"{what} failed: {}",
err_string(result)
);
}

View File

@@ -9,8 +9,8 @@ pub mod ffi;
use std::ffi::c_void;
use ffi::{NcclComm, NcclUniqueId};
use xserv_cuda::device;
use xserv_cuda::GpuBuffer;
use xserv_cuda::device;
pub use ffi::NcclUniqueId as UniqueId;
@@ -55,7 +55,12 @@ impl TpContext {
"ncclCommInitRank",
);
ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(init)");
Self { rank, world, device, comm }
Self {
rank,
world,
device,
comm,
}
}
/// In-place AllReduce(sum) over `count` BF16 elements in `buf`.
@@ -127,7 +132,12 @@ impl PpContext {
"ncclCommInitRank",
);
ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(init)");
Self { stage, world, device, comm }
Self {
stage,
world,
device,
comm,
}
}
/// Send `count` BF16 elements at `ptr` to `peer`, on the null stream so it is
@@ -138,7 +148,16 @@ impl PpContext {
/// `ptr` must point to at least `count` BF16 elements of valid device memory.
pub fn send_bf16_ptr(&self, ptr: *const c_void, count: usize, peer: usize) {
ffi::check(
unsafe { ffi::ncclSend(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, launch_stream()) },
unsafe {
ffi::ncclSend(
ptr,
count,
ffi::NCCL_BF16,
peer as i32,
self.comm,
launch_stream(),
)
},
"ncclSend",
);
}
@@ -149,7 +168,16 @@ impl PpContext {
/// `ptr` must point to at least `count` BF16 elements of valid device memory.
pub fn recv_bf16_ptr(&self, ptr: *mut c_void, count: usize, peer: usize) {
ffi::check(
unsafe { ffi::ncclRecv(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, launch_stream()) },
unsafe {
ffi::ncclRecv(
ptr,
count,
ffi::NCCL_BF16,
peer as i32,
self.comm,
launch_stream(),
)
},
"ncclRecv",
);
}

View File

@@ -2,8 +2,8 @@
use half::bf16;
use std::thread;
use xserv_cuda::{device, GpuBuffer};
use xserv_distributed::{get_unique_id, TpContext};
use xserv_cuda::{GpuBuffer, device};
use xserv_distributed::{TpContext, get_unique_id};
#[test]
fn allreduce_two_gpu_sum() {
@@ -25,9 +25,7 @@ fn allreduce_two_gpu_sum() {
// Rank r fills its buffer with (r + 1).
let val = bf16::from_f32((rank + 1) as f32);
let host = vec![val; n];
let src = unsafe {
std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2)
};
let src = unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) };
let mut buf = GpuBuffer::alloc(n * 2).unwrap();
buf.copy_from_host(src).unwrap();

View File

@@ -6,8 +6,8 @@
use half::bf16;
use std::ffi::c_void;
use std::thread;
use xserv_cuda::{device, GpuBuffer};
use xserv_distributed::{get_unique_id, PpContext};
use xserv_cuda::{GpuBuffer, device};
use xserv_distributed::{PpContext, get_unique_id};
#[test]
fn pp_send_recv_two_stages() {
@@ -30,7 +30,8 @@ fn pp_send_recv_two_stages() {
if stage == 0 {
// Fill with a known pattern and send to stage 1.
let host: Vec<bf16> = (0..n).map(|i| bf16::from_f32((i % 97) as f32)).collect();
let src = unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) };
let src =
unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) };
buf.copy_from_host(src).unwrap();
pp.send_bf16_ptr(buf.as_mut_ptr() as *const c_void, n, 1);
device::synchronize().unwrap();

View File

@@ -6,78 +6,189 @@ unsafe extern "C" {
fn launch_gelu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_silu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_silu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_scale_f32(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void);
fn launch_scale_bf16(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void);
fn launch_add_f32(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_mul_f32(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_mul_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_gpt_oss_glu_bf16(gate_up: *const c_void, out: *mut c_void, n_elements: i32,
alpha: f32, limit: f32, stream: *mut c_void);
fn launch_bias_add_2d_bf16(x: *const c_void, bias: *const c_void, out: *mut c_void,
rows: i32, cols: i32, stream: *mut c_void);
fn launch_scale_f32(
x: *const c_void,
out: *mut c_void,
scale: f32,
n: i32,
stream: *mut c_void,
);
fn launch_scale_bf16(
x: *const c_void,
out: *mut c_void,
scale: f32,
n: i32,
stream: *mut c_void,
);
fn launch_add_f32(
a: *const c_void,
b: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
);
fn launch_add_bf16(
a: *const c_void,
b: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
);
fn launch_mul_f32(
a: *const c_void,
b: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
);
fn launch_mul_bf16(
a: *const c_void,
b: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
);
fn launch_silu_mul_bf16(
gate: *const c_void,
up: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
);
fn launch_gpt_oss_glu_bf16(
gate_up: *const c_void,
out: *mut c_void,
n_elements: i32,
alpha: f32,
limit: f32,
stream: *mut c_void,
);
fn launch_bias_add_2d_bf16(
x: *const c_void,
bias: *const c_void,
out: *mut c_void,
rows: i32,
cols: i32,
stream: *mut c_void,
);
}
fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
fn dispatch_unary(
x: &Tensor,
f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
) -> Tensor {
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
let n = x.numel();
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
assert!(
n <= i32::MAX as usize,
"tensor too large for i32 kernel param ({n} elements)"
);
let n = n as i32;
unsafe {
match x.dtype() {
DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
DType::BF16 => bf16_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
DType::F32 => f32_fn(
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
n,
xserv_cuda::current_stream_raw(),
),
DType::BF16 => bf16_fn(
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
n,
xserv_cuda::current_stream_raw(),
),
_ => panic!("unsupported dtype"),
}
}
out
}
fn dispatch_binary(a: &Tensor, b: &Tensor,
f32_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void),
bf16_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
fn dispatch_binary(
a: &Tensor,
b: &Tensor,
f32_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void),
bf16_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void),
) -> Tensor {
assert_eq!(a.shape(), b.shape());
assert!(a.is_contiguous() && b.is_contiguous());
assert!(matches!(a.device(), Device::Cuda(_)));
assert_eq!(a.dtype(), b.dtype());
let out = Tensor::empty(a.shape(), a.dtype(), a.device());
let n = a.numel();
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
assert!(
n <= i32::MAX as usize,
"tensor too large for i32 kernel param ({n} elements)"
);
let n = n as i32;
unsafe {
match a.dtype() {
DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
DType::BF16 => bf16_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
DType::F32 => f32_fn(
a.data_ptr() as _,
b.data_ptr() as _,
out.data_ptr() as *mut c_void,
n,
xserv_cuda::current_stream_raw(),
),
DType::BF16 => bf16_fn(
a.data_ptr() as _,
b.data_ptr() as _,
out.data_ptr() as *mut c_void,
n,
xserv_cuda::current_stream_raw(),
),
_ => panic!("unsupported dtype"),
}
}
out
}
pub fn gelu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_gelu_f32, launch_gelu_bf16) }
pub fn silu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_silu_f32, launch_silu_bf16) }
pub fn gelu(x: &Tensor) -> Tensor {
dispatch_unary(x, launch_gelu_f32, launch_gelu_bf16)
}
pub fn silu(x: &Tensor) -> Tensor {
dispatch_unary(x, launch_silu_f32, launch_silu_bf16)
}
pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
let n = x.numel();
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
assert!(
n <= i32::MAX as usize,
"tensor too large for i32 kernel param ({n} elements)"
);
let n = n as i32;
unsafe {
match x.dtype() {
DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, xserv_cuda::current_stream_raw()),
DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, xserv_cuda::current_stream_raw()),
DType::F32 => launch_scale_f32(
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
scale_val,
n,
xserv_cuda::current_stream_raw(),
),
DType::BF16 => launch_scale_bf16(
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
scale_val,
n,
xserv_cuda::current_stream_raw(),
),
_ => panic!("unsupported dtype for scale"),
}
}
out
}
pub fn add(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_add_f32, launch_add_bf16) }
pub fn mul(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16) }
pub fn add(a: &Tensor, b: &Tensor) -> Tensor {
dispatch_binary(a, b, launch_add_f32, launch_add_bf16)
}
pub fn mul(a: &Tensor, b: &Tensor) -> Tensor {
dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16)
}
/// Row-broadcast bias add: out[r, c] = x[r, c] + bias[c] (BF16 only).
pub fn bias_add_2d(x: &Tensor, bias: &Tensor) -> Tensor {
@@ -89,13 +200,22 @@ pub fn bias_add_2d(x: &Tensor, bias: &Tensor) -> Tensor {
assert!(matches!(x.device(), Device::Cuda(_)));
let rows = x.shape()[0];
let cols = x.shape()[1];
assert_eq!(bias.shape()[0], cols, "bias size {} != cols {cols}", bias.shape()[0]);
assert_eq!(
bias.shape()[0],
cols,
"bias size {} != cols {cols}",
bias.shape()[0]
);
assert!(rows * cols <= i32::MAX as usize);
let out = Tensor::empty(&[rows, cols], DType::BF16, x.device());
unsafe {
launch_bias_add_2d_bf16(
x.data_ptr() as _, bias.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, cols as i32, xserv_cuda::current_stream_raw(),
x.data_ptr() as _,
bias.data_ptr() as _,
out.data_ptr() as *mut c_void,
rows as i32,
cols as i32,
xserv_cuda::current_stream_raw(),
);
}
out
@@ -110,7 +230,10 @@ pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {
assert_eq!(gate.dtype(), DType::BF16, "silu_mul requires BF16");
let out = Tensor::empty(gate.shape(), gate.dtype(), gate.device());
let n = gate.numel();
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
assert!(
n <= i32::MAX as usize,
"tensor too large for i32 kernel param ({n} elements)"
);
let n = n as i32;
unsafe {
launch_silu_mul_bf16(

View File

@@ -2,8 +2,13 @@ use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_argmax_bf16(logits: *const c_void, out_idx: *mut c_void,
rows: i32, cols: i32, stream: *mut c_void);
fn launch_argmax_bf16(
logits: *const c_void,
out_idx: *mut c_void,
rows: i32,
cols: i32,
stream: *mut c_void,
);
}
/// GPU argmax over the last dim of a [rows, cols] BF16 tensor.
@@ -19,7 +24,10 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec<u32> {
assert_eq!(logits.ndim(), 2, "argmax expects a 2D [rows, cols] tensor");
assert_eq!(logits.dtype(), DType::BF16, "argmax kernel is BF16-only");
assert!(logits.is_contiguous(), "argmax requires contiguous input");
assert!(matches!(logits.device(), Device::Cuda(_)), "argmax requires GPU input");
assert!(
matches!(logits.device(), Device::Cuda(_)),
"argmax requires GPU input"
);
let rows = logits.shape()[0];
let cols = logits.shape()[1];
@@ -35,7 +43,8 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec<u32> {
launch_argmax_bf16(
logits.data_ptr() as *const c_void,
out.as_mut_ptr() as *mut c_void,
rows as i32, cols as i32,
rows as i32,
cols as i32,
xserv_cuda::current_stream_raw(),
);
}
@@ -44,9 +53,8 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec<u32> {
out.copy_to_host(&mut host_bytes).expect("argmax D2H");
drop(out); // returned to pool
let host_i32: &[i32] = unsafe {
std::slice::from_raw_parts(host_bytes.as_ptr() as *const i32, rows)
};
let host_i32: &[i32] =
unsafe { std::slice::from_raw_parts(host_bytes.as_ptr() as *const i32, rows) };
host_i32.iter().map(|&v| v as u32).collect()
}
@@ -62,4 +70,3 @@ pub fn argmax_bf16_single(logits: &Tensor) -> u32 {
};
argmax_bf16_to_host(&view)[0]
}

View File

@@ -6,28 +6,67 @@ use crate::gemm::batched_matmul;
use crate::softmax::softmax;
unsafe extern "C" {
fn launch_causal_mask_f32(scores: *mut c_void, batch: i32, rows: i32, cols: i32,
offset: i32, stream: *mut c_void);
fn launch_causal_mask_bf16(scores: *mut c_void, batch: i32, rows: i32, cols: i32,
offset: i32, stream: *mut c_void);
fn launch_causal_mask_f32(
scores: *mut c_void,
batch: i32,
rows: i32,
cols: i32,
offset: i32,
stream: *mut c_void,
);
fn launch_causal_mask_bf16(
scores: *mut c_void,
batch: i32,
rows: i32,
cols: i32,
offset: i32,
stream: *mut c_void,
);
fn launch_flash_attention_bf16(
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
q_len: i32, kv_len: i32, head_dim: i32,
scale: f32, causal: i32, stream: *mut c_void,
q: *const c_void,
k: *const c_void,
v: *const c_void,
o: *mut c_void,
batch: i32,
num_q_heads: i32,
num_kv_heads: i32,
q_len: i32,
kv_len: i32,
head_dim: i32,
scale: f32,
causal: i32,
stream: *mut c_void,
);
fn launch_flash_attention_sinks_bf16(
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
q: *const c_void,
k: *const c_void,
v: *const c_void,
o: *mut c_void,
sinks: *const c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
q_len: i32, kv_len: i32, head_dim: i32,
scale: f32, causal: i32, window_size: i32, stream: *mut c_void,
batch: i32,
num_q_heads: i32,
num_kv_heads: i32,
q_len: i32,
kv_len: i32,
head_dim: i32,
scale: f32,
causal: i32,
window_size: i32,
stream: *mut c_void,
);
fn launch_decode_attention_bf16(
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
kv_len: i32, head_dim: i32,
scale: f32, causal: i32, stream: *mut c_void,
q: *const c_void,
k: *const c_void,
v: *const c_void,
o: *mut c_void,
batch: i32,
num_q_heads: i32,
num_kv_heads: i32,
kv_len: i32,
head_dim: i32,
scale: f32,
causal: i32,
stream: *mut c_void,
);
fn launch_paged_decode_attention_bf16(
q: *const c_void,
@@ -36,9 +75,13 @@ unsafe extern "C" {
o: *mut c_void,
block_tables: *const i32,
context_lens: *const i32,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
head_dim: i32, max_blocks_per_seq: i32,
scale: f32, stream: *mut c_void,
batch: i32,
num_q_heads: i32,
num_kv_heads: i32,
head_dim: i32,
max_blocks_per_seq: i32,
scale: f32,
stream: *mut c_void,
);
fn launch_paged_decode_attention_sinks_bf16(
q: *const c_void,
@@ -48,24 +91,40 @@ unsafe extern "C" {
block_tables: *const i32,
context_lens: *const i32,
sinks: *const c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
head_dim: i32, max_blocks_per_seq: i32,
scale: f32, window_size: i32, stream: *mut c_void,
batch: i32,
num_q_heads: i32,
num_kv_heads: i32,
head_dim: i32,
max_blocks_per_seq: i32,
scale: f32,
window_size: i32,
stream: *mut c_void,
);
fn launch_reshape_and_cache_bf16(
k_src: *const c_void, v_src: *const c_void,
k_pool: *mut c_void, v_pool: *mut c_void,
k_src: *const c_void,
v_src: *const c_void,
k_pool: *mut c_void,
v_pool: *mut c_void,
block_ids: *const c_void,
num_tokens: i32, num_heads: i32,
head_dim: i32, start_pos: i32, block_size: i32,
num_tokens: i32,
num_heads: i32,
head_dim: i32,
start_pos: i32,
block_size: i32,
stream: *mut c_void,
);
fn launch_reshape_and_cache_batched_bf16(
k_src: *const c_void, v_src: *const c_void,
k_pool: *mut c_void, v_pool: *mut c_void,
block_tables: *const c_void, kv_lens: *const c_void,
batch: i32, num_heads: i32,
head_dim: i32, block_size: i32, max_blocks_per_seq: i32,
k_src: *const c_void,
v_src: *const c_void,
k_pool: *mut c_void,
v_pool: *mut c_void,
block_tables: *const c_void,
kv_lens: *const c_void,
batch: i32,
num_heads: i32,
head_dim: i32,
block_size: i32,
max_blocks_per_seq: i32,
stream: *mut c_void,
);
}
@@ -84,20 +143,30 @@ unsafe extern "C" {
/// `block_ids_gpu` must contain at least `(start_pos + num_tokens + block_size - 1) / block_size`
/// valid physical block ids.
pub unsafe fn reshape_and_cache_bf16(
k_src: *const c_void, v_src: *const c_void,
k_pool_ptr: *mut c_void, v_pool_ptr: *mut c_void,
k_src: *const c_void,
v_src: *const c_void,
k_pool_ptr: *mut c_void,
v_pool_ptr: *mut c_void,
block_ids_gpu: *const i32,
num_tokens: usize, num_heads: usize,
head_dim: usize, start_pos: usize, block_size: usize,
num_tokens: usize,
num_heads: usize,
head_dim: usize,
start_pos: usize,
block_size: usize,
stream: *mut c_void,
) {
unsafe {
launch_reshape_and_cache_bf16(
k_src, v_src,
k_pool_ptr, v_pool_ptr,
k_src,
v_src,
k_pool_ptr,
v_pool_ptr,
block_ids_gpu as *const c_void,
num_tokens as i32, num_heads as i32,
head_dim as i32, start_pos as i32, block_size as i32,
num_tokens as i32,
num_heads as i32,
head_dim as i32,
start_pos as i32,
block_size as i32,
stream,
);
}
@@ -113,21 +182,32 @@ pub unsafe fn reshape_and_cache_bf16(
/// All pointers must be on the same GPU. `block_tables` and `kv_lens` must
/// already be synced to the device for the active batch.
pub unsafe fn reshape_and_cache_batched_bf16(
k_src: *const c_void, v_src: *const c_void,
k_pool_ptr: *mut c_void, v_pool_ptr: *mut c_void,
block_tables_gpu: *const i32, kv_lens_gpu: *const i32,
batch: usize, num_heads: usize,
head_dim: usize, block_size: usize, max_blocks_per_seq: usize,
k_src: *const c_void,
v_src: *const c_void,
k_pool_ptr: *mut c_void,
v_pool_ptr: *mut c_void,
block_tables_gpu: *const i32,
kv_lens_gpu: *const i32,
batch: usize,
num_heads: usize,
head_dim: usize,
block_size: usize,
max_blocks_per_seq: usize,
stream: *mut c_void,
) {
unsafe {
launch_reshape_and_cache_batched_bf16(
k_src, v_src,
k_pool_ptr, v_pool_ptr,
k_src,
v_src,
k_pool_ptr,
v_pool_ptr,
block_tables_gpu as *const c_void,
kv_lens_gpu as *const c_void,
batch as i32, num_heads as i32,
head_dim as i32, block_size as i32, max_blocks_per_seq as i32,
batch as i32,
num_heads as i32,
head_dim as i32,
block_size as i32,
max_blocks_per_seq as i32,
stream,
);
}
@@ -143,12 +223,18 @@ fn apply_causal_mask(scores: &Tensor, offset: usize) {
match scores.dtype() {
DType::F32 => launch_causal_mask_f32(
scores.data_ptr() as *mut c_void,
batch as i32, rows as i32, cols as i32, offset as i32,
batch as i32,
rows as i32,
cols as i32,
offset as i32,
xserv_cuda::current_stream_raw(),
),
DType::BF16 => launch_causal_mask_bf16(
scores.data_ptr() as *mut c_void,
batch as i32, rows as i32, cols as i32, offset as i32,
batch as i32,
rows as i32,
cols as i32,
offset as i32,
xserv_cuda::current_stream_raw(),
),
_ => panic!("unsupported dtype for causal mask"),
@@ -214,11 +300,7 @@ pub fn decode_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Tensor {
let kv_len = k.shape()[2];
let scale = 1.0 / (head_dim as f32).sqrt();
let output = Tensor::empty(
&[batch, num_q_heads, 1, head_dim],
DType::BF16,
q.device(),
);
let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
unsafe {
launch_decode_attention_bf16(
@@ -266,8 +348,14 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens
assert_eq!(k.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
assert_eq!(v.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
assert!(num_q_heads % num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads");
assert!(head_dim <= 128, "flash_attention supports head_dim up to 128");
assert!(
num_q_heads % num_kv_heads == 0,
"num_q_heads must be divisible by num_kv_heads"
);
assert!(
head_dim <= 128,
"flash_attention supports head_dim up to 128"
);
// Dispatch to specialized decode kernel for single-token generation
if q_len == 1 {
@@ -333,10 +421,18 @@ pub fn flash_attention_sinks(
assert_eq!(v.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
assert!(num_q_heads % num_kv_heads == 0);
assert!(head_dim <= 128);
assert_eq!(sinks.shape()[0], num_q_heads, "sinks must have num_q_heads entries");
assert_eq!(
sinks.shape()[0],
num_q_heads,
"sinks must have num_q_heads entries"
);
let scale = 1.0 / (head_dim as f32).sqrt();
let output = Tensor::empty(&[batch, num_q_heads, q_len, head_dim], DType::BF16, q.device());
let output = Tensor::empty(
&[batch, num_q_heads, q_len, head_dim],
DType::BF16,
q.device(),
);
unsafe {
launch_flash_attention_sinks_bf16(
@@ -383,17 +479,20 @@ pub fn paged_decode_attention(
max_blocks_per_seq: usize,
) -> Tensor {
assert_eq!(q.ndim(), 4);
assert_eq!(q.shape()[2], 1, "paged_decode_attention requires q_len == 1");
assert_eq!(
q.shape()[2],
1,
"paged_decode_attention requires q_len == 1"
);
assert_eq!(q.dtype(), DType::BF16);
assert!(num_q_heads % num_kv_heads == 0, "GQA: num_q_heads must be divisible by num_kv_heads");
assert!(
num_q_heads % num_kv_heads == 0,
"GQA: num_q_heads must be divisible by num_kv_heads"
);
assert!(head_dim <= 128);
let scale = 1.0 / (head_dim as f32).sqrt();
let output = Tensor::empty(
&[batch, num_q_heads, 1, head_dim],
DType::BF16,
q.device(),
);
let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
unsafe {
launch_paged_decode_attention_bf16(
@@ -442,11 +541,7 @@ pub fn paged_decode_attention_sinks(
assert!(head_dim <= 128);
let scale = 1.0 / (head_dim as f32).sqrt();
let output = Tensor::empty(
&[batch, num_q_heads, 1, head_dim],
DType::BF16,
q.device(),
);
let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
unsafe {
launch_paged_decode_attention_sinks_bf16(

View File

@@ -5,104 +5,302 @@ use std::ffi::c_void;
// Re-declare the extern functions we need (same as in the individual modules)
unsafe extern "C" {
fn launch_rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
fn launch_add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void,
normed_out: *mut c_void, sum_out: *mut c_void,
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
fn launch_silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void);
fn launch_reshape_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
fn launch_merge_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
fn launch_transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
fn launch_transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
fn launch_rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
positions: *const c_void, num_tokens: i32, num_heads: i32,
head_dim: i32, stream: *mut c_void);
fn launch_gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void, y_fp32_buf: *mut c_void,
k: i32, n: i32, stream: *mut c_void);
fn launch_rmsnorm_bf16(
x: *const c_void,
gamma: *const c_void,
out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
);
fn launch_add_rmsnorm_bf16(
x: *const c_void,
residual: *const c_void,
gamma: *const c_void,
normed_out: *mut c_void,
sum_out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
);
fn launch_silu_mul_bf16(
gate: *const c_void,
up: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
);
fn launch_add_bf16(
a: *const c_void,
b: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
);
fn launch_embedding_bf16(
table: *const c_void,
token_ids: *const c_void,
out: *mut c_void,
num_tokens: i32,
hidden_size: i32,
vocab_size: i32,
stream: *mut c_void,
);
fn launch_reshape_heads_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_merge_heads_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_transpose_hsd_to_shd_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_transpose_shd_to_hsd_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_rope_bf16(
x: *mut c_void,
cos_cache: *const c_void,
sin_cache: *const c_void,
positions: *const c_void,
num_tokens: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_gemv_bf16(
x: *const c_void,
w: *const c_void,
y_bf16: *mut c_void,
y_fp32_buf: *mut c_void,
k: i32,
n: i32,
stream: *mut c_void,
);
fn launch_decode_attention_bf16(
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
kv_len: i32, head_dim: i32,
scale: f32, causal: i32, stream: *mut c_void,
q: *const c_void,
k: *const c_void,
v: *const c_void,
o: *mut c_void,
batch: i32,
num_q_heads: i32,
num_kv_heads: i32,
kv_len: i32,
head_dim: i32,
scale: f32,
causal: i32,
stream: *mut c_void,
);
}
/// Raw rmsnorm dispatch: writes to pre-allocated `out`.
pub unsafe fn rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void) {
pub unsafe fn rmsnorm_bf16(
x: *const c_void,
gamma: *const c_void,
out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
) {
launch_rmsnorm_bf16(x, gamma, out, rows, hidden_size, eps, stream);
}
/// Raw add_rmsnorm dispatch.
pub unsafe fn add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void,
normed_out: *mut c_void, sum_out: *mut c_void,
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void) {
launch_add_rmsnorm_bf16(x, residual, gamma, normed_out, sum_out, rows, hidden_size, eps, stream);
pub unsafe fn add_rmsnorm_bf16(
x: *const c_void,
residual: *const c_void,
gamma: *const c_void,
normed_out: *mut c_void,
sum_out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
) {
launch_add_rmsnorm_bf16(
x,
residual,
gamma,
normed_out,
sum_out,
rows,
hidden_size,
eps,
stream,
);
}
/// Raw silu_mul dispatch.
pub unsafe fn silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void) {
pub unsafe fn silu_mul_bf16(
gate: *const c_void,
up: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
) {
launch_silu_mul_bf16(gate, up, out, n, stream);
}
/// Raw add dispatch.
pub unsafe fn add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void) {
pub unsafe fn add_bf16(
a: *const c_void,
b: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
) {
launch_add_bf16(a, b, out, n, stream);
}
/// Raw embedding dispatch.
pub unsafe fn embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void) {
launch_embedding_bf16(table, token_ids, out, num_tokens, hidden_size, vocab_size, stream);
pub unsafe fn embedding_bf16(
table: *const c_void,
token_ids: *const c_void,
out: *mut c_void,
num_tokens: i32,
hidden_size: i32,
vocab_size: i32,
stream: *mut c_void,
) {
launch_embedding_bf16(
table,
token_ids,
out,
num_tokens,
hidden_size,
vocab_size,
stream,
);
}
/// Raw reshape_heads dispatch.
pub unsafe fn reshape_heads_bf16(inp: *const c_void, out: *mut c_void,
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
pub unsafe fn reshape_heads_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
) {
launch_reshape_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream);
}
/// Raw merge_heads dispatch.
pub unsafe fn merge_heads_bf16(inp: *const c_void, out: *mut c_void,
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
pub unsafe fn merge_heads_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
) {
launch_merge_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream);
}
/// Raw transpose HSD->SHD dispatch.
pub unsafe fn transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void,
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
pub unsafe fn transpose_hsd_to_shd_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
) {
launch_transpose_hsd_to_shd_bf16(inp, out, seq_len, num_heads, head_dim, stream);
}
/// Raw transpose SHD->HSD dispatch.
pub unsafe fn transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void,
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
pub unsafe fn transpose_shd_to_hsd_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
) {
launch_transpose_shd_to_hsd_bf16(inp, out, seq_len, num_heads, head_dim, stream);
}
/// Raw RoPE dispatch (in-place).
pub unsafe fn rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
positions: *const c_void, num_tokens: i32, num_heads: i32,
head_dim: i32, stream: *mut c_void) {
launch_rope_bf16(x, cos_cache, sin_cache, positions, num_tokens, num_heads, head_dim, stream);
pub unsafe fn rope_bf16(
x: *mut c_void,
cos_cache: *const c_void,
sin_cache: *const c_void,
positions: *const c_void,
num_tokens: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
) {
launch_rope_bf16(
x, cos_cache, sin_cache, positions, num_tokens, num_heads, head_dim, stream,
);
}
/// Raw GEMV dispatch (BF16, M=1). Caller must provide fp32 accumulator buffer.
pub unsafe fn gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void,
y_fp32_buf: *mut c_void, k: i32, n: i32, stream: *mut c_void) {
pub unsafe fn gemv_bf16(
x: *const c_void,
w: *const c_void,
y_bf16: *mut c_void,
y_fp32_buf: *mut c_void,
k: i32,
n: i32,
stream: *mut c_void,
) {
launch_gemv_bf16(x, w, y_bf16, y_fp32_buf, k, n, stream);
}
/// Raw decode attention dispatch.
pub unsafe fn decode_attention_bf16(q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
kv_len: i32, head_dim: i32,
scale: f32, stream: *mut c_void) {
launch_decode_attention_bf16(q, k, v, o, batch, num_q_heads, num_kv_heads, kv_len, head_dim, scale, 1, stream);
pub unsafe fn decode_attention_bf16(
q: *const c_void,
k: *const c_void,
v: *const c_void,
o: *mut c_void,
batch: i32,
num_q_heads: i32,
num_kv_heads: i32,
kv_len: i32,
head_dim: i32,
scale: f32,
stream: *mut c_void,
) {
launch_decode_attention_bf16(
q,
k,
v,
o,
batch,
num_q_heads,
num_kv_heads,
kv_len,
head_dim,
scale,
1,
stream,
);
}
// cuBLAS FFI

View File

@@ -2,10 +2,24 @@ use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_embedding_f32(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void);
fn launch_embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void);
fn launch_embedding_f32(
table: *const c_void,
token_ids: *const c_void,
out: *mut c_void,
num_tokens: i32,
hidden_size: i32,
vocab_size: i32,
stream: *mut c_void,
);
fn launch_embedding_bf16(
table: *const c_void,
token_ids: *const c_void,
out: *mut c_void,
num_tokens: i32,
hidden_size: i32,
vocab_size: i32,
stream: *mut c_void,
);
}
/// Embedding lookup: table[token_ids[i]] for each i.
@@ -18,8 +32,14 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
let hidden_size = table.shape()[1];
let num_tokens = token_ids.len();
let vocab_size = table.shape()[0];
assert!(num_tokens <= i32::MAX as usize, "too many tokens for i32 kernel param");
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
assert!(
num_tokens <= i32::MAX as usize,
"too many tokens for i32 kernel param"
);
assert!(
hidden_size <= i32::MAX as usize,
"hidden_size too large for i32 kernel param"
);
// Upload token_ids to GPU
let ids_bytes = unsafe {
@@ -28,11 +48,15 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
num_tokens * std::mem::size_of::<u32>(),
)
};
let mut ids_gpu = xserv_cuda::allocator::cached_alloc(ids_bytes.len()).expect("alloc token_ids");
let mut ids_gpu =
xserv_cuda::allocator::cached_alloc(ids_bytes.len()).expect("alloc token_ids");
ids_gpu.copy_from_host(ids_bytes).unwrap();
for &tid in token_ids {
assert!((tid as usize) < vocab_size, "token_id {tid} out of bounds (vocab_size={vocab_size})");
assert!(
(tid as usize) < vocab_size,
"token_id {tid} out of bounds (vocab_size={vocab_size})"
);
}
embedding_device_ids(table, ids_gpu.as_ptr() as *const c_void, num_tokens)
@@ -53,14 +77,22 @@ pub fn embedding_device_ids(table: &Tensor, ids_gpu: *const c_void, num_tokens:
unsafe {
match table.dtype() {
DType::F32 => launch_embedding_f32(
table.data_ptr() as _, ids_gpu,
table.data_ptr() as _,
ids_gpu,
out.data_ptr() as *mut c_void,
num_tokens as i32, hidden_size as i32, vocab_size as i32, xserv_cuda::current_stream_raw(),
num_tokens as i32,
hidden_size as i32,
vocab_size as i32,
xserv_cuda::current_stream_raw(),
),
DType::BF16 => launch_embedding_bf16(
table.data_ptr() as _, ids_gpu,
table.data_ptr() as _,
ids_gpu,
out.data_ptr() as *mut c_void,
num_tokens as i32, hidden_size as i32, vocab_size as i32, xserv_cuda::current_stream_raw(),
num_tokens as i32,
hidden_size as i32,
vocab_size as i32,
xserv_cuda::current_stream_raw(),
),
_ => panic!("unsupported dtype for embedding"),
}

View File

@@ -1,14 +1,22 @@
use std::cell::RefCell;
use std::ffi::c_void;
use xserv_cuda::error::{self, Result};
use xserv_cuda::GpuBuffer;
use xserv_cuda::error::{self, Result};
use xserv_tensor::{DType, Device, Tensor};
const CUBLAS_WORKSPACE_BYTES: usize = 32 * 1024 * 1024;
// GEMV: single-kernel, no FP32 temp buffer needed
unsafe extern "C" {
fn launch_gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void, y_fp32_buf: *mut c_void, k: i32, n: i32, stream: *mut c_void);
fn launch_gemv_bf16(
x: *const c_void,
w: *const c_void,
y_bf16: *mut c_void,
y_fp32_buf: *mut c_void,
k: i32,
n: i32,
stream: *mut c_void,
);
}
#[derive(Debug, Clone, Copy)]
@@ -20,10 +28,42 @@ pub enum GemmBackend {
// --- FFI: custom CUDA kernels ---
unsafe extern "C" {
fn launch_gemm_naive_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
fn launch_gemm_naive_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
fn launch_gemm_tiled_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
fn launch_gemm_tiled_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
fn launch_gemm_naive_f32(
a: *const c_void,
b: *const c_void,
c: *mut c_void,
m: i32,
n: i32,
k: i32,
stream: *mut c_void,
);
fn launch_gemm_naive_bf16(
a: *const c_void,
b: *const c_void,
c: *mut c_void,
m: i32,
n: i32,
k: i32,
stream: *mut c_void,
);
fn launch_gemm_tiled_f32(
a: *const c_void,
b: *const c_void,
c: *mut c_void,
m: i32,
n: i32,
k: i32,
stream: *mut c_void,
);
fn launch_gemm_tiled_bf16(
a: *const c_void,
b: *const c_void,
c: *mut c_void,
m: i32,
n: i32,
k: i32,
stream: *mut c_void,
);
}
// --- FFI: cuBLAS ---
@@ -46,25 +86,46 @@ unsafe extern "C" {
fn cublasSetWorkspace_v2(handle: CublasHandle, workspace: *mut c_void, size: usize) -> i32;
fn cublasGemmEx(
handle: CublasHandle,
transa: i32, transb: i32,
m: i32, n: i32, k: i32,
transa: i32,
transb: i32,
m: i32,
n: i32,
k: i32,
alpha: *const c_void,
a: *const c_void, a_type: i32, lda: i32,
b: *const c_void, b_type: i32, ldb: i32,
a: *const c_void,
a_type: i32,
lda: i32,
b: *const c_void,
b_type: i32,
ldb: i32,
beta: *const c_void,
c: *mut c_void, c_type: i32, ldc: i32,
c: *mut c_void,
c_type: i32,
ldc: i32,
compute_type: i32,
algo: i32,
) -> i32;
fn cublasGemmStridedBatchedEx(
handle: CublasHandle,
transa: i32, transb: i32,
m: i32, n: i32, k: i32,
transa: i32,
transb: i32,
m: i32,
n: i32,
k: i32,
alpha: *const c_void,
a: *const c_void, a_type: i32, lda: i32, stride_a: i64,
b: *const c_void, b_type: i32, ldb: i32, stride_b: i64,
a: *const c_void,
a_type: i32,
lda: i32,
stride_a: i64,
b: *const c_void,
b_type: i32,
ldb: i32,
stride_b: i64,
beta: *const c_void,
c: *mut c_void, c_type: i32, ldc: i32, stride_c: i64,
c: *mut c_void,
c_type: i32,
ldc: i32,
stride_c: i64,
batch_count: i32,
compute_type: i32,
algo: i32,
@@ -89,9 +150,16 @@ impl CublasContext {
// set, so we keep the GpuBuffer in this struct.
let mut workspace = GpuBuffer::alloc(CUBLAS_WORKSPACE_BYTES)?;
error::check(unsafe {
cublasSetWorkspace_v2(handle, workspace.as_mut_ptr() as *mut c_void, CUBLAS_WORKSPACE_BYTES)
cublasSetWorkspace_v2(
handle,
workspace.as_mut_ptr() as *mut c_void,
CUBLAS_WORKSPACE_BYTES,
)
})?;
Ok(Self { handle, _workspace: Some(workspace) })
Ok(Self {
handle,
_workspace: Some(workspace),
})
}
}
@@ -123,9 +191,7 @@ where
/// Get the thread-local cuBLAS handle for use with dispatch module.
pub fn cublas_handle() -> CublasHandle {
CUBLAS_CTX.with(|cell| {
cell.borrow().handle
})
CUBLAS_CTX.with(|cell| cell.borrow().handle)
}
/// Matrix multiplication: C = A @ B
@@ -136,8 +202,14 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
assert_eq!(b.ndim(), 2);
assert_eq!(a.shape()[1], b.shape()[0], "inner dimension mismatch");
assert_eq!(a.dtype(), b.dtype(), "dtype mismatch");
assert!(a.is_contiguous() && b.is_contiguous(), "matmul requires contiguous tensors");
assert!(matches!(a.device(), Device::Cuda(_)), "matmul requires GPU tensors");
assert!(
a.is_contiguous() && b.is_contiguous(),
"matmul requires contiguous tensors"
);
assert!(
matches!(a.device(), Device::Cuda(_)),
"matmul requires GPU tensors"
);
let m = a.shape()[0];
let k = a.shape()[1];
@@ -154,32 +226,63 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
let null_stream = xserv_cuda::current_stream_raw();
match backend {
GemmBackend::Naive => {
unsafe {
match dtype {
DType::F32 => launch_gemm_naive_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
DType::BF16 => launch_gemm_naive_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
_ => panic!("unsupported dtype for naive GEMM"),
}
GemmBackend::Naive => unsafe {
match dtype {
DType::F32 => launch_gemm_naive_f32(
a_ptr,
b_ptr,
c_ptr,
m as i32,
n as i32,
k as i32,
null_stream,
),
DType::BF16 => launch_gemm_naive_bf16(
a_ptr,
b_ptr,
c_ptr,
m as i32,
n as i32,
k as i32,
null_stream,
),
_ => panic!("unsupported dtype for naive GEMM"),
}
}
GemmBackend::Tiled => {
unsafe {
match dtype {
DType::F32 => launch_gemm_tiled_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
DType::BF16 => launch_gemm_tiled_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
_ => panic!("unsupported dtype for tiled GEMM"),
}
},
GemmBackend::Tiled => unsafe {
match dtype {
DType::F32 => launch_gemm_tiled_f32(
a_ptr,
b_ptr,
c_ptr,
m as i32,
n as i32,
k as i32,
null_stream,
),
DType::BF16 => launch_gemm_tiled_bf16(
a_ptr,
b_ptr,
c_ptr,
m as i32,
n as i32,
k as i32,
null_stream,
),
_ => panic!("unsupported dtype for tiled GEMM"),
}
}
},
GemmBackend::CuBlas => {
if m == 1 && dtype == DType::BF16 && n >= 256 {
let mut fp32_buf = xserv_cuda::allocator::cached_alloc(n * 4).unwrap();
unsafe {
launch_gemv_bf16(
a_ptr, b_ptr, c_ptr,
a_ptr,
b_ptr,
c_ptr,
fp32_buf.as_mut_ptr() as *mut c_void,
k as i32, n as i32,
k as i32,
n as i32,
null_stream,
);
}
@@ -197,16 +300,26 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
cublasSetStream_v2(handle, null_stream);
error::check(cublasGemmEx(
handle,
CUBLAS_OP_N, CUBLAS_OP_N,
n as i32, m as i32, k as i32,
CUBLAS_OP_N,
CUBLAS_OP_N,
n as i32,
m as i32,
k as i32,
&alpha as *const f32 as *const c_void,
b_ptr, b_type, n as i32,
a_ptr, a_type, k as i32,
b_ptr,
b_type,
n as i32,
a_ptr,
a_type,
k as i32,
&beta as *const f32 as *const c_void,
c_ptr, c_type, n as i32,
c_ptr,
c_type,
n as i32,
CUBLAS_COMPUTE_32F,
-1,
)).expect("cuBLAS GEMM failed");
))
.expect("cuBLAS GEMM failed");
});
}
}
@@ -264,17 +377,30 @@ pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor {
// Row-major trick: C = A @ B ⟺ C^T = B^T @ A^T (col-major)
error::check(cublasGemmStridedBatchedEx(
handle,
CUBLAS_OP_N, CUBLAS_OP_N,
n as i32, m as i32, k as i32,
CUBLAS_OP_N,
CUBLAS_OP_N,
n as i32,
m as i32,
k as i32,
&alpha as *const f32 as *const c_void,
b.data_ptr() as _, b_type, n as i32, stride_b,
a.data_ptr() as _, a_type, k as i32, stride_a,
b.data_ptr() as _,
b_type,
n as i32,
stride_b,
a.data_ptr() as _,
a_type,
k as i32,
stride_a,
&beta as *const f32 as *const c_void,
c.data_ptr() as *mut c_void, c_type, n as i32, stride_c,
c.data_ptr() as *mut c_void,
c_type,
n as i32,
stride_c,
batch as i32,
CUBLAS_COMPUTE_32F,
-1,
)).expect("cuBLAS batched GEMM failed");
))
.expect("cuBLAS batched GEMM failed");
});
c
}

View File

@@ -2,10 +2,26 @@ use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_layernorm_f32(x: *const c_void, gamma: *const c_void, beta: *const c_void,
out: *mut c_void, rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
fn launch_layernorm_bf16(x: *const c_void, gamma: *const c_void, beta: *const c_void,
out: *mut c_void, rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
fn launch_layernorm_f32(
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
);
fn launch_layernorm_bf16(
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
);
}
pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor {
@@ -17,21 +33,37 @@ pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor
assert_eq!(beta.shape(), &[hidden_size]);
let rows = x.numel() / hidden_size;
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
assert!(
rows <= i32::MAX as usize,
"too many rows for i32 kernel param"
);
assert!(
hidden_size <= i32::MAX as usize,
"hidden_size too large for i32 kernel param"
);
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
unsafe {
match x.dtype() {
DType::F32 => launch_layernorm_f32(
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
x.data_ptr() as _,
gamma.data_ptr() as _,
beta.data_ptr() as _,
out.data_ptr() as *mut c_void,
rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
rows as i32,
hidden_size as i32,
eps,
xserv_cuda::current_stream_raw(),
),
DType::BF16 => launch_layernorm_bf16(
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
x.data_ptr() as _,
gamma.data_ptr() as _,
beta.data_ptr() as _,
out.data_ptr() as *mut c_void,
rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
rows as i32,
hidden_size as i32,
eps,
xserv_cuda::current_stream_raw(),
),
_ => panic!("unsupported dtype for layernorm"),
}

View File

@@ -14,14 +14,20 @@ pub mod transpose;
pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu_mul};
pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
pub use attention::{attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention, paged_decode_attention_sinks, reshape_and_cache_bf16, reshape_and_cache_batched_bf16};
pub use attention::{
attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention,
paged_decode_attention_sinks, reshape_and_cache_batched_bf16, reshape_and_cache_bf16,
};
pub use embedding::{embedding, embedding_device_ids};
pub use gemm::{batched_matmul, matmul, GemmBackend};
pub use gemm::{GemmBackend, batched_matmul, matmul};
pub use layernorm::layernorm;
pub use rmsnorm::{add_rmsnorm, rmsnorm};
pub use rope::{rope_inplace, rope_inplace_device_pos, RopeCache};
pub use rope::{RopeCache, rope_inplace, rope_inplace_device_pos};
pub use softmax::softmax;
pub use transpose::{
merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu,
transpose_for_rope_gpu, transpose_from_rope_gpu,
};
/// Register GPU kernels with the tensor crate. Call once at startup.
pub fn init() {

View File

@@ -1,66 +1,113 @@
use std::ffi::c_void;
use xserv_tensor::{DType, Tensor};
use crate::gemm::{cublas_handle, CublasHandle};
use crate::gemm::{CublasHandle, cublas_handle};
unsafe extern "C" {
fn launch_moe_topk_softmax_bf16(
router_logits: *const c_void,
topk_ids: *mut c_void, topk_weights: *mut c_void,
num_tokens: i32, num_experts: i32, top_k: i32,
topk_ids: *mut c_void,
topk_weights: *mut c_void,
num_tokens: i32,
num_experts: i32,
top_k: i32,
stream: *mut c_void,
);
fn launch_moe_replicate_bf16(
x: *const c_void, x_rep: *mut c_void,
num_tokens: i32, hidden: i32, local_experts: i32,
x: *const c_void,
x_rep: *mut c_void,
num_tokens: i32,
hidden: i32,
local_experts: i32,
stream: *mut c_void,
);
fn launch_moe_bias_add_3d_bf16(
x: *mut c_void, bias: *const c_void,
batch: i32, num_tokens: i32, dim: i32,
x: *mut c_void,
bias: *const c_void,
batch: i32,
num_tokens: i32,
dim: i32,
stream: *mut c_void,
);
fn launch_moe_weighted_sum_bf16(
expert_out: *const c_void,
topk_ids: *const c_void, topk_weights: *const c_void,
topk_ids: *const c_void,
topk_weights: *const c_void,
out: *mut c_void,
num_tokens: i32, hidden: i32, top_k: i32,
expert_start: i32, local_experts: i32,
num_tokens: i32,
hidden: i32,
top_k: i32,
expert_start: i32,
local_experts: i32,
stream: *mut c_void,
);
fn launch_moe_sparse_gemv_fp8_bf16(
x: *const c_void, w: *const c_void, w_scales: *const c_void,
bias: *const c_void, topk_ids: *const c_void, y: *mut c_void,
num_tokens: i32, n: i32, k: i32, top_k: i32,
expert_start: i32, local_experts: i32, x_per_slot: i32,
x: *const c_void,
w: *const c_void,
w_scales: *const c_void,
bias: *const c_void,
topk_ids: *const c_void,
y: *mut c_void,
num_tokens: i32,
n: i32,
k: i32,
top_k: i32,
expert_start: i32,
local_experts: i32,
x_per_slot: i32,
stream: *mut c_void,
);
fn launch_moe_sparse_gemv_mxfp4_bf16(
x: *const c_void, w_packed: *const c_void, w_scales: *const c_void,
bias: *const c_void, topk_ids: *const c_void, y: *mut c_void,
num_tokens: i32, n: i32, k: i32, top_k: i32,
expert_start: i32, local_experts: i32, x_per_slot: i32,
x: *const c_void,
w_packed: *const c_void,
w_scales: *const c_void,
bias: *const c_void,
topk_ids: *const c_void,
y: *mut c_void,
num_tokens: i32,
n: i32,
k: i32,
top_k: i32,
expert_start: i32,
local_experts: i32,
x_per_slot: i32,
stream: *mut c_void,
);
fn launch_moe_weighted_sum_sparse_bf16(
down: *const c_void,
topk_ids: *const c_void, topk_weights: *const c_void,
topk_ids: *const c_void,
topk_weights: *const c_void,
out: *mut c_void,
num_tokens: i32, hidden: i32, top_k: i32,
expert_start: i32, local_experts: i32,
num_tokens: i32,
hidden: i32,
top_k: i32,
expert_start: i32,
local_experts: i32,
stream: *mut c_void,
);
fn cublasGemmStridedBatchedEx(
handle: CublasHandle,
transa: i32, transb: i32,
m: i32, n: i32, k: i32,
transa: i32,
transb: i32,
m: i32,
n: i32,
k: i32,
alpha: *const c_void,
a: *const c_void, a_type: i32, lda: i32, stride_a: i64,
b: *const c_void, b_type: i32, ldb: i32, stride_b: i64,
a: *const c_void,
a_type: i32,
lda: i32,
stride_a: i64,
b: *const c_void,
b_type: i32,
ldb: i32,
stride_b: i64,
beta: *const c_void,
c: *mut c_void, c_type: i32, ldc: i32, stride_c: i64,
c: *mut c_void,
c_type: i32,
ldc: i32,
stride_c: i64,
batch_count: i32,
compute_type: i32,
algo: i32,
@@ -99,7 +146,9 @@ pub fn moe_topk_softmax(
router_logits.data_ptr() as *const c_void,
topk_ids.data_ptr() as *mut c_void,
topk_weights.data_ptr() as *mut c_void,
num_tokens as i32, num_experts as i32, top_k as i32,
num_tokens as i32,
num_experts as i32,
top_k as i32,
xserv_cuda::current_stream_raw(),
);
}
@@ -114,13 +163,19 @@ pub fn moe_replicate(x: &Tensor, local_experts: usize) -> Tensor {
assert!(x.is_contiguous());
let num_tokens = x.shape()[0];
let hidden = x.shape()[1];
let out = Tensor::empty(&[local_experts, num_tokens, hidden], DType::BF16, x.device());
let out = Tensor::empty(
&[local_experts, num_tokens, hidden],
DType::BF16,
x.device(),
);
unsafe {
launch_moe_replicate_bf16(
x.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
num_tokens as i32, hidden as i32, local_experts as i32,
num_tokens as i32,
hidden as i32,
local_experts as i32,
xserv_cuda::current_stream_raw(),
);
}
@@ -143,7 +198,9 @@ pub fn moe_bias_add_3d(x: &Tensor, bias: &Tensor) {
launch_moe_bias_add_3d_bf16(
x.data_ptr() as *mut c_void,
bias.data_ptr() as *const c_void,
batch as i32, num_tokens as i32, dim as i32,
batch as i32,
num_tokens as i32,
dim as i32,
xserv_cuda::current_stream_raw(),
);
}
@@ -175,8 +232,11 @@ pub fn moe_weighted_sum(
topk_ids.data_ptr() as *const c_void,
topk_weights.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
num_tokens as i32, hidden as i32, top_k as i32,
expert_start as i32, local_experts as i32,
num_tokens as i32,
hidden as i32,
top_k as i32,
expert_start as i32,
local_experts as i32,
xserv_cuda::current_stream_raw(),
);
}
@@ -198,9 +258,16 @@ pub fn moe_weighted_sum(
/// consumer must skip them (see moe_weighted_sum_sparse).
#[allow(clippy::too_many_arguments)]
pub fn moe_sparse_gemv_fp8(
x: &Tensor, w_fp8_t: &Tensor, w_scales: &Tensor, bias: &Tensor,
topk_ids: &Tensor, num_tokens: usize, top_k: usize,
expert_start: usize, local_experts: usize, x_per_slot: bool,
x: &Tensor,
w_fp8_t: &Tensor,
w_scales: &Tensor,
bias: &Tensor,
topk_ids: &Tensor,
num_tokens: usize,
top_k: usize,
expert_start: usize,
local_experts: usize,
x_per_slot: bool,
) -> Tensor {
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous());
@@ -211,7 +278,14 @@ pub fn moe_sparse_gemv_fp8(
// silently skip a K%16 tail.
assert_eq!(k % 16, 0, "sparse FP8 GEMV requires K % 16 == 0, got {k}");
assert_eq!(x.shape()[x.ndim() - 1], k);
assert_eq!(x.shape()[0], if x_per_slot { num_tokens * top_k } else { num_tokens });
assert_eq!(
x.shape()[0],
if x_per_slot {
num_tokens * top_k
} else {
num_tokens
}
);
let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device());
unsafe {
@@ -222,8 +296,13 @@ pub fn moe_sparse_gemv_fp8(
bias.data_ptr() as *const c_void,
topk_ids.data_ptr() as *const c_void,
y.data_ptr() as *mut c_void,
num_tokens as i32, n as i32, k as i32, top_k as i32,
expert_start as i32, local_experts as i32, x_per_slot as i32,
num_tokens as i32,
n as i32,
k as i32,
top_k as i32,
expert_start as i32,
local_experts as i32,
x_per_slot as i32,
xserv_cuda::current_stream_raw(),
);
}
@@ -234,16 +313,32 @@ pub fn moe_sparse_gemv_fp8(
/// with packed 4-bit weights [E, N, K/2] + UE8M0 block scales [E, N, K/32].
#[allow(clippy::too_many_arguments)]
pub fn moe_sparse_gemv_mxfp4(
x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, bias: &Tensor,
topk_ids: &Tensor, num_tokens: usize, top_k: usize, n: usize, k: usize,
expert_start: usize, local_experts: usize, x_per_slot: bool,
x: &Tensor,
w_packed: &Tensor,
w_scales: &Tensor,
bias: &Tensor,
topk_ids: &Tensor,
num_tokens: usize,
top_k: usize,
n: usize,
k: usize,
expert_start: usize,
local_experts: usize,
x_per_slot: bool,
) -> Tensor {
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous());
// 32-element MXFP4 blocks, read as uint4 (32 nibbles) per lane.
assert_eq!(k % 32, 0, "sparse MXFP4 GEMV requires K % 32 == 0, got {k}");
assert_eq!(x.shape()[x.ndim() - 1], k);
assert_eq!(x.shape()[0], if x_per_slot { num_tokens * top_k } else { num_tokens });
assert_eq!(
x.shape()[0],
if x_per_slot {
num_tokens * top_k
} else {
num_tokens
}
);
let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device());
unsafe {
@@ -254,8 +349,13 @@ pub fn moe_sparse_gemv_mxfp4(
bias.data_ptr() as *const c_void,
topk_ids.data_ptr() as *const c_void,
y.data_ptr() as *mut c_void,
num_tokens as i32, n as i32, k as i32, top_k as i32,
expert_start as i32, local_experts as i32, x_per_slot as i32,
num_tokens as i32,
n as i32,
k as i32,
top_k as i32,
expert_start as i32,
local_experts as i32,
x_per_slot as i32,
xserv_cuda::current_stream_raw(),
);
}
@@ -286,8 +386,11 @@ pub fn moe_weighted_sum_sparse(
topk_ids.data_ptr() as *const c_void,
topk_weights.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
num_tokens as i32, hidden as i32, top_k as i32,
expert_start as i32, local_experts as i32,
num_tokens as i32,
hidden as i32,
top_k as i32,
expert_start as i32,
local_experts as i32,
xserv_cuda::current_stream_raw(),
);
}
@@ -341,13 +444,25 @@ pub fn batched_gemm_strided(a: &Tensor, b: &Tensor) -> Tensor {
cublasSetStream_v2(handle, xserv_cuda::current_stream_raw());
let status = cublasGemmStridedBatchedEx(
handle,
0, 0, // CUBLAS_OP_N, CUBLAS_OP_N
n as i32, m as i32, k as i32,
0,
0, // CUBLAS_OP_N, CUBLAS_OP_N
n as i32,
m as i32,
k as i32,
&alpha as *const f32 as *const c_void,
b.data_ptr() as *const c_void, CUDA_R_16BF, n as i32, stride_b,
a.data_ptr() as *const c_void, CUDA_R_16BF, k as i32, stride_a,
b.data_ptr() as *const c_void,
CUDA_R_16BF,
n as i32,
stride_b,
a.data_ptr() as *const c_void,
CUDA_R_16BF,
k as i32,
stride_a,
&beta as *const f32 as *const c_void,
c.data_ptr() as *mut c_void, CUDA_R_16BF, n as i32, stride_c,
c.data_ptr() as *mut c_void,
CUDA_R_16BF,
n as i32,
stride_c,
batch as i32,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT,

View File

@@ -13,30 +13,46 @@ unsafe extern "C" {
src: *const c_void,
scales: *const c_void,
dst: *mut c_void,
num_experts: i32, rows: i32, cols: i32,
num_experts: i32,
rows: i32,
cols: i32,
stream: *mut c_void,
);
fn launch_quantize_bf16_to_fp8e4m3_rowwise(
src: *const c_void,
dst: *mut c_void,
scales: *mut c_void,
num_rows: i32, cols: i32,
num_rows: i32,
cols: i32,
stream: *mut c_void,
);
fn launch_rowwise_scale_moe_bf16(
data: *mut c_void,
a_scales: *const c_void,
b_scales: *const c_void,
num_rows: i32, cols: i32, tokens: i32,
num_rows: i32,
cols: i32,
tokens: i32,
stream: *mut c_void,
);
fn launch_batched_gemv_mxfp4_bf16(
x: *const c_void, w_packed: *const c_void, w_scales: *const c_void, y: *mut c_void,
e: i32, n: i32, k: i32, stream: *mut c_void,
x: *const c_void,
w_packed: *const c_void,
w_scales: *const c_void,
y: *mut c_void,
e: i32,
n: i32,
k: i32,
stream: *mut c_void,
);
fn launch_dequant_mxfp4_to_bf16_t(
w_packed: *const c_void, w_scales: *const c_void, out: *mut c_void,
e: i32, n: i32, k: i32, stream: *mut c_void,
w_packed: *const c_void,
w_scales: *const c_void,
out: *mut c_void,
e: i32,
n: i32,
k: i32,
stream: *mut c_void,
);
}
@@ -66,34 +82,68 @@ struct CublasLtMatmulHeuristicResult {
unsafe extern "C" {
fn cublasLtCreate(handle: *mut CublasLtHandle) -> i32;
fn cublasLtDestroy(handle: CublasLtHandle) -> i32;
fn cublasLtMatmulDescCreate(desc: *mut CublasLtMatmulDesc, compute_type: i32, scale_type: i32) -> i32;
fn cublasLtMatmulDescCreate(
desc: *mut CublasLtMatmulDesc,
compute_type: i32,
scale_type: i32,
) -> i32;
fn cublasLtMatmulDescDestroy(desc: CublasLtMatmulDesc) -> i32;
fn cublasLtMatmulDescSetAttribute(desc: CublasLtMatmulDesc, attr: i32, buf: *const c_void, size: usize) -> i32;
fn cublasLtMatrixLayoutCreate(layout: *mut CublasLtMatrixLayout, dtype: i32, rows: u64, cols: u64, ld: i64) -> i32;
fn cublasLtMatmulDescSetAttribute(
desc: CublasLtMatmulDesc,
attr: i32,
buf: *const c_void,
size: usize,
) -> i32;
fn cublasLtMatrixLayoutCreate(
layout: *mut CublasLtMatrixLayout,
dtype: i32,
rows: u64,
cols: u64,
ld: i64,
) -> i32;
fn cublasLtMatrixLayoutDestroy(layout: CublasLtMatrixLayout) -> i32;
fn cublasLtMatrixLayoutSetAttribute(layout: CublasLtMatrixLayout, attr: i32, buf: *const c_void, size: usize) -> i32;
fn cublasLtMatrixLayoutSetAttribute(
layout: CublasLtMatrixLayout,
attr: i32,
buf: *const c_void,
size: usize,
) -> i32;
fn cublasLtMatmulPreferenceCreate(pref: *mut CublasLtMatmulPreference) -> i32;
fn cublasLtMatmulPreferenceDestroy(pref: CublasLtMatmulPreference) -> i32;
fn cublasLtMatmulPreferenceSetAttribute(pref: CublasLtMatmulPreference, attr: i32, buf: *const c_void, size: usize) -> i32;
fn cublasLtMatmulPreferenceSetAttribute(
pref: CublasLtMatmulPreference,
attr: i32,
buf: *const c_void,
size: usize,
) -> i32;
fn cublasLtMatmulAlgoGetHeuristic(
handle: CublasLtHandle, desc: CublasLtMatmulDesc,
a_layout: CublasLtMatrixLayout, b_layout: CublasLtMatrixLayout,
c_layout: CublasLtMatrixLayout, d_layout: CublasLtMatrixLayout,
handle: CublasLtHandle,
desc: CublasLtMatmulDesc,
a_layout: CublasLtMatrixLayout,
b_layout: CublasLtMatrixLayout,
c_layout: CublasLtMatrixLayout,
d_layout: CublasLtMatrixLayout,
pref: CublasLtMatmulPreference,
requested: i32,
results: *mut CublasLtMatmulHeuristicResult,
found: *mut i32,
) -> i32;
fn cublasLtMatmul(
handle: CublasLtHandle, desc: CublasLtMatmulDesc,
handle: CublasLtHandle,
desc: CublasLtMatmulDesc,
alpha: *const c_void,
a: *const c_void, a_layout: CublasLtMatrixLayout,
b: *const c_void, b_layout: CublasLtMatrixLayout,
a: *const c_void,
a_layout: CublasLtMatrixLayout,
b: *const c_void,
b_layout: CublasLtMatrixLayout,
beta: *const c_void,
c: *const c_void, c_layout: CublasLtMatrixLayout,
d: *mut c_void, d_layout: CublasLtMatrixLayout,
c: *const c_void,
c_layout: CublasLtMatrixLayout,
d: *mut c_void,
d_layout: CublasLtMatrixLayout,
algo: *const CublasLtMatmulAlgo,
workspace: *mut c_void, workspace_size: usize,
workspace: *mut c_void,
workspace_size: usize,
stream: *mut c_void,
) -> i32;
}
@@ -153,8 +203,15 @@ impl CublasLtContext {
assert_eq!(status, 0, "cublasLtCreate failed: {status}");
let workspace = GpuBuffer::alloc(WORKSPACE_BYTES).expect("alloc cublasLt workspace");
let mut one_buf = GpuBuffer::alloc(4).expect("alloc cublasLt fp8 scale");
one_buf.copy_from_host(&1.0f32.to_le_bytes()).expect("init fp8 scale");
Self { handle, workspace, one_buf, plans: HashMap::new() }
one_buf
.copy_from_host(&1.0f32.to_le_bytes())
.expect("init fp8 scale");
Self {
handle,
workspace,
one_buf,
plans: HashMap::new(),
}
}
/// Get the cached strided-batched plan for (m, n, k, batch), building it on
@@ -210,10 +267,25 @@ unsafe fn build_fp8_plan(
// transA=T (required for FP8 on Blackwell)
let trans_a: i32 = 1;
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a as *const i32 as _, 4);
cublasLtMatmulDescSetAttribute(
desc,
CUBLASLT_MATMUL_DESC_TRANSA,
&trans_a as *const i32 as _,
4,
);
let ptr_sz = std::mem::size_of::<*const c_void>();
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz);
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz);
cublasLtMatmulDescSetAttribute(
desc,
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&one_ptr as *const _ as _,
ptr_sz,
);
cublasLtMatmulDescSetAttribute(
desc,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&one_ptr as *const _ as _,
ptr_sz,
);
// Per-expert strides in ELEMENTS for the strided-batch layout.
let stride_a = (n * k) as i64; // weights [N, K]
@@ -221,10 +293,18 @@ unsafe fn build_fp8_plan(
let stride_c = (m * n) as i64; // output [M, N]
let bc = batch as i32;
let set_batch = |layout: CublasLtMatrixLayout, stride: i64| {
cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&bc as *const i32 as _, 4);
cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&stride as *const i64 as _, 8);
cublasLtMatrixLayoutSetAttribute(
layout,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&bc as *const i32 as _,
4,
);
cublasLtMatrixLayoutSetAttribute(
layout,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&stride as *const i64 as _,
8,
);
};
// "A" layout (weights, transposed): physical (K, N) col-major, ld=K
@@ -246,20 +326,39 @@ unsafe fn build_fp8_plan(
let mut pref: CublasLtMatmulPreference = std::ptr::null_mut();
cublasLtMatmulPreferenceCreate(&mut pref);
let ws_bytes = WORKSPACE_BYTES as u64;
cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &ws_bytes as *const u64 as _, 8);
cublasLtMatmulPreferenceSetAttribute(
pref,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&ws_bytes as *const u64 as _,
8,
);
let mut heuristic = std::mem::zeroed::<CublasLtMatmulHeuristicResult>();
let mut found: i32 = 0;
let status = cublasLtMatmulAlgoGetHeuristic(
handle, desc, a_layout, b_layout, c_layout, d_layout,
pref, 1, &mut heuristic, &mut found,
handle,
desc,
a_layout,
b_layout,
c_layout,
d_layout,
pref,
1,
&mut heuristic,
&mut found,
);
assert!(
status == 0 && found > 0,
"cublasLtMatmulAlgoGetHeuristic failed for batched FP8 GEMM (m={m}, n={n}, k={k}, batch={batch}): status={status}, found={found}"
);
assert!(status == 0 && found > 0,
"cublasLtMatmulAlgoGetHeuristic failed for batched FP8 GEMM (m={m}, n={n}, k={k}, batch={batch}): status={status}, found={found}");
cublasLtMatmulPreferenceDestroy(pref);
Fp8Plan {
desc, a_layout, b_layout, c_layout, d_layout,
desc,
a_layout,
b_layout,
c_layout,
d_layout,
algo: heuristic.algo,
workspace_size: heuristic.workspace_size,
}
@@ -299,7 +398,9 @@ pub fn dequant_fp8_to_bf16(src: &Tensor, scales: &Tensor) -> Tensor {
src.data_ptr() as *const c_void,
scales.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
num_experts as i32, rows as i32, cols as i32,
num_experts as i32,
rows as i32,
cols as i32,
xserv_cuda::current_stream_raw(),
);
}
@@ -329,7 +430,8 @@ pub fn quantize_bf16_to_fp8_rowwise(src: &Tensor) -> (Tensor, Tensor) {
src.data_ptr() as *const c_void,
fp8_out.data_ptr() as *mut c_void,
scales.data_ptr() as *mut c_void,
num_rows as i32, cols as i32,
num_rows as i32,
cols as i32,
xserv_cuda::current_stream_raw(),
);
}
@@ -392,23 +494,27 @@ pub fn batched_gemm_fp8(
unsafe {
let status = cublasLtMatmul(
handle, plan.desc,
handle,
plan.desc,
&alpha as *const f32 as _,
b_fp8_t.data_ptr() as *const c_void, // cuBLASLt "A" = weights
plan.a_layout,
a_fp8.data_ptr() as *const c_void, // cuBLASLt "B" = activations
a_fp8.data_ptr() as *const c_void, // cuBLASLt "B" = activations
plan.b_layout,
&beta as *const f32 as _,
c.data_ptr() as *const c_void, // C (unused with beta=0)
c.data_ptr() as *const c_void, // C (unused with beta=0)
plan.c_layout,
c.data_ptr() as *mut c_void, // D = output
c.data_ptr() as *mut c_void, // D = output
plan.d_layout,
&plan.algo,
ws_ptr,
plan.workspace_size,
xserv_cuda::current_stream_raw(),
);
assert_eq!(status, 0, "batched cublasLtMatmul FP8 failed: status={status}");
assert_eq!(
status, 0,
"batched cublasLtMatmul FP8 failed: status={status}"
);
}
});
@@ -423,7 +529,9 @@ pub fn batched_gemm_fp8(
c.data_ptr() as *mut c_void,
a_scales.data_ptr() as *const c_void,
b_scales.data_ptr() as *const c_void,
total_rows, n as i32, m as i32,
total_rows,
n as i32,
m as i32,
xserv_cuda::current_stream_raw(),
);
}
@@ -442,7 +550,13 @@ pub fn batched_gemm_fp8(
/// w_scales: [E, N, K/32] byte tensor — UE8M0 scale per 32-element block
///
/// Returns: [E, N] BF16, where y[e,n] = sum_k x[e,k] * dequant(W[e,n,k]).
pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: usize, k: usize) -> Tensor {
pub fn batched_gemv_mxfp4(
x: &Tensor,
w_packed: &Tensor,
w_scales: &Tensor,
n: usize,
k: usize,
) -> Tensor {
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous());
let e = x.shape()[0];
@@ -455,7 +569,9 @@ pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: u
w_packed.data_ptr() as *const c_void,
w_scales.data_ptr() as *const c_void,
y.data_ptr() as *mut c_void,
e as i32, n as i32, k as i32,
e as i32,
n as i32,
k as i32,
xserv_cuda::current_stream_raw(),
);
}
@@ -464,14 +580,22 @@ pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: u
/// Dequantize MXFP4 weights [E, N, K] → BF16 [E, K, N] for the prefill GEMM path
/// (the BF16 batched GEMM expects weights as [E, K, N]).
pub fn dequant_mxfp4_to_bf16_t(w_packed: &Tensor, w_scales: &Tensor, e: usize, n: usize, k: usize) -> Tensor {
pub fn dequant_mxfp4_to_bf16_t(
w_packed: &Tensor,
w_scales: &Tensor,
e: usize,
n: usize,
k: usize,
) -> Tensor {
let out = Tensor::empty(&[e, k, n], DType::BF16, w_packed.device());
unsafe {
launch_dequant_mxfp4_to_bf16_t(
w_packed.data_ptr() as *const c_void,
w_scales.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
e as i32, n as i32, k as i32,
e as i32,
n as i32,
k as i32,
xserv_cuda::current_stream_raw(),
);
}

View File

@@ -2,13 +2,35 @@ use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_rmsnorm_f32(x: *const c_void, gamma: *const c_void, out: *mut c_void,
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
fn launch_rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
fn launch_add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void,
normed_out: *mut c_void, sum_out: *mut c_void,
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
fn launch_rmsnorm_f32(
x: *const c_void,
gamma: *const c_void,
out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
);
fn launch_rmsnorm_bf16(
x: *const c_void,
gamma: *const c_void,
out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
);
fn launch_add_rmsnorm_bf16(
x: *const c_void,
residual: *const c_void,
gamma: *const c_void,
normed_out: *mut c_void,
sum_out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
);
}
pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
@@ -20,19 +42,35 @@ pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
assert_eq!(x.dtype(), gamma.dtype());
let rows = x.numel() / hidden_size;
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
assert!(
rows <= i32::MAX as usize,
"too many rows for i32 kernel param"
);
assert!(
hidden_size <= i32::MAX as usize,
"hidden_size too large for i32 kernel param"
);
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
unsafe {
match x.dtype() {
DType::F32 => launch_rmsnorm_f32(
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
x.data_ptr() as _,
gamma.data_ptr() as _,
out.data_ptr() as *mut c_void,
rows as i32,
hidden_size as i32,
eps,
xserv_cuda::current_stream_raw(),
),
DType::BF16 => launch_rmsnorm_bf16(
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
x.data_ptr() as _,
gamma.data_ptr() as _,
out.data_ptr() as *mut c_void,
rows as i32,
hidden_size as i32,
eps,
xserv_cuda::current_stream_raw(),
),
_ => panic!("unsupported dtype for rmsnorm"),
}
@@ -56,8 +94,14 @@ pub fn add_rmsnorm(x: &Tensor, residual: &Tensor, gamma: &Tensor, eps: f32) -> (
assert_eq!(gamma.shape(), &[hidden_size]);
let rows = x.numel() / hidden_size;
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
assert!(
rows <= i32::MAX as usize,
"too many rows for i32 kernel param"
);
assert!(
hidden_size <= i32::MAX as usize,
"hidden_size too large for i32 kernel param"
);
let normed_out = Tensor::empty(x.shape(), DType::BF16, x.device());
let sum_out = Tensor::empty(x.shape(), DType::BF16, x.device());

View File

@@ -3,15 +3,34 @@ use xserv_cuda::GpuBuffer;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_rope_f32(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
positions: *const c_void, num_tokens: i32, num_heads: i32,
head_dim: i32, stream: *mut c_void);
fn launch_rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
positions: *const c_void, num_tokens: i32, num_heads: i32,
head_dim: i32, stream: *mut c_void);
fn launch_compute_rope_cache(cos_cache: *mut c_void, sin_cache: *mut c_void,
max_seq_len: i32, half_dim: i32, theta: f32,
stream: *mut c_void);
fn launch_rope_f32(
x: *mut c_void,
cos_cache: *const c_void,
sin_cache: *const c_void,
positions: *const c_void,
num_tokens: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_rope_bf16(
x: *mut c_void,
cos_cache: *const c_void,
sin_cache: *const c_void,
positions: *const c_void,
num_tokens: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_compute_rope_cache(
cos_cache: *mut c_void,
sin_cache: *mut c_void,
max_seq_len: i32,
half_dim: i32,
theta: f32,
stream: *mut c_void,
);
}
pub struct RopeCache {
@@ -30,12 +49,21 @@ impl RopeCache {
unsafe {
launch_compute_rope_cache(
cos.as_mut_ptr() as _, sin.as_mut_ptr() as _,
max_seq_len as i32, half_dim as i32, theta, xserv_cuda::current_stream_raw(),
cos.as_mut_ptr() as _,
sin.as_mut_ptr() as _,
max_seq_len as i32,
half_dim as i32,
theta,
xserv_cuda::current_stream_raw(),
);
}
Self { cos, sin, max_seq_len, half_dim }
Self {
cos,
sin,
max_seq_len,
half_dim,
}
}
/// YaRN (Yet another RoPE extensioN) RoPE cache. Applies frequency-dependent
@@ -68,8 +96,8 @@ impl RopeCache {
let mut inv_freq = vec![0.0f64; half_dim];
for i in 0..half_dim {
let pos_freq = theta.powf((2 * i) as f64 / dim);
let inv_freq_extrapolation = 1.0 / pos_freq; // original
let inv_freq_interpolation = 1.0 / (factor * pos_freq); // scaled
let inv_freq_extrapolation = 1.0 / pos_freq; // original
let inv_freq_interpolation = 1.0 / (factor * pos_freq); // scaled
// Linear ramp: 0 where we keep original, 1 where we interpolate
let ramp = if (high - low).abs() < 0.001 {
@@ -101,16 +129,19 @@ impl RopeCache {
let nbytes = total * std::mem::size_of::<f32>();
let mut cos = GpuBuffer::alloc(nbytes).expect("alloc yarn cos_cache");
let mut sin = GpuBuffer::alloc(nbytes).expect("alloc yarn sin_cache");
let cos_bytes = unsafe {
std::slice::from_raw_parts(cos_host.as_ptr() as *const u8, nbytes)
};
let sin_bytes = unsafe {
std::slice::from_raw_parts(sin_host.as_ptr() as *const u8, nbytes)
};
let cos_bytes =
unsafe { std::slice::from_raw_parts(cos_host.as_ptr() as *const u8, nbytes) };
let sin_bytes =
unsafe { std::slice::from_raw_parts(sin_host.as_ptr() as *const u8, nbytes) };
cos.copy_from_host(cos_bytes).unwrap();
sin.copy_from_host(sin_bytes).unwrap();
Self { cos, sin, max_seq_len, half_dim }
Self {
cos,
sin,
max_seq_len,
half_dim,
}
}
}
@@ -133,7 +164,8 @@ pub fn rope_inplace(x: &Tensor, cache: &RopeCache, positions: &[u32]) {
num_tokens * std::mem::size_of::<u32>(),
)
};
let mut pos_gpu = xserv_cuda::allocator::cached_alloc(pos_bytes.len()).expect("alloc positions");
let mut pos_gpu =
xserv_cuda::allocator::cached_alloc(pos_bytes.len()).expect("alloc positions");
pos_gpu.copy_from_host(pos_bytes).unwrap();
rope_inplace_device_pos(x, cache, pos_gpu.as_ptr() as *const c_void);
@@ -155,16 +187,22 @@ pub fn rope_inplace_device_pos(x: &Tensor, cache: &RopeCache, pos_gpu: *const c_
match x.dtype() {
DType::F32 => launch_rope_f32(
x.data_ptr() as *mut c_void,
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
cache.cos.as_ptr() as _,
cache.sin.as_ptr() as _,
pos_gpu,
num_tokens as i32, num_heads as i32, head_dim as i32,
num_tokens as i32,
num_heads as i32,
head_dim as i32,
xserv_cuda::current_stream_raw(),
),
DType::BF16 => launch_rope_bf16(
x.data_ptr() as *mut c_void,
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
cache.cos.as_ptr() as _,
cache.sin.as_ptr() as _,
pos_gpu,
num_tokens as i32, num_heads as i32, head_dim as i32,
num_tokens as i32,
num_heads as i32,
head_dim as i32,
xserv_cuda::current_stream_raw(),
),
_ => panic!("unsupported dtype for rope"),

View File

@@ -2,8 +2,20 @@ use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_softmax_f32(x: *const c_void, out: *mut c_void, rows: i32, cols: i32, stream: *mut c_void);
fn launch_softmax_bf16(x: *const c_void, out: *mut c_void, rows: i32, cols: i32, stream: *mut c_void);
fn launch_softmax_f32(
x: *const c_void,
out: *mut c_void,
rows: i32,
cols: i32,
stream: *mut c_void,
);
fn launch_softmax_bf16(
x: *const c_void,
out: *mut c_void,
rows: i32,
cols: i32,
stream: *mut c_void,
);
}
/// Softmax along the last dimension.
@@ -14,19 +26,31 @@ pub fn softmax(x: &Tensor) -> Tensor {
let cols = *x.shape().last().unwrap();
let rows = x.numel() / cols;
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
assert!(cols <= i32::MAX as usize, "cols too large for i32 kernel param");
assert!(
rows <= i32::MAX as usize,
"too many rows for i32 kernel param"
);
assert!(
cols <= i32::MAX as usize,
"cols too large for i32 kernel param"
);
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
unsafe {
match x.dtype() {
DType::F32 => launch_softmax_f32(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, cols as i32, xserv_cuda::current_stream_raw(),
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
rows as i32,
cols as i32,
xserv_cuda::current_stream_raw(),
),
DType::BF16 => launch_softmax_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, cols as i32, xserv_cuda::current_stream_raw(),
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
rows as i32,
cols as i32,
xserv_cuda::current_stream_raw(),
),
_ => panic!("unsupported dtype for softmax"),
}

View File

@@ -2,19 +2,79 @@ use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_reshape_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
fn launch_merge_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
fn launch_transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
fn launch_transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
fn launch_repeat_kv_bf16(inp: *const c_void, out: *mut c_void, kv_heads: i32, n_rep: i32, seq_len: i32, head_dim: i32, stream: *mut c_void);
fn launch_strided_copy_bf16(inp: *const c_void, out: *mut c_void, numel: i32, ndim: i32,
shape0: i32, shape1: i32, shape2: i32, shape3: i32,
in_stride0: i32, in_stride1: i32, in_stride2: i32, in_stride3: i32,
in_offset: i32, stream: *mut c_void);
fn launch_strided_copy_f32(inp: *const c_void, out: *mut c_void, numel: i32, ndim: i32,
shape0: i32, shape1: i32, shape2: i32, shape3: i32,
in_stride0: i32, in_stride1: i32, in_stride2: i32, in_stride3: i32,
in_offset: i32, stream: *mut c_void);
fn launch_reshape_heads_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_merge_heads_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_transpose_hsd_to_shd_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_transpose_shd_to_hsd_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_repeat_kv_bf16(
inp: *const c_void,
out: *mut c_void,
kv_heads: i32,
n_rep: i32,
seq_len: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_strided_copy_bf16(
inp: *const c_void,
out: *mut c_void,
numel: i32,
ndim: i32,
shape0: i32,
shape1: i32,
shape2: i32,
shape3: i32,
in_stride0: i32,
in_stride1: i32,
in_stride2: i32,
in_stride3: i32,
in_offset: i32,
stream: *mut c_void,
);
fn launch_strided_copy_f32(
inp: *const c_void,
out: *mut c_void,
numel: i32,
ndim: i32,
shape0: i32,
shape1: i32,
shape2: i32,
shape3: i32,
in_stride0: i32,
in_stride1: i32,
in_stride2: i32,
in_stride3: i32,
in_offset: i32,
stream: *mut c_void,
);
}
/// [S, H*D] → [1, H, S, D] on GPU (BF16)
@@ -24,8 +84,12 @@ pub fn reshape_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim:
let out = Tensor::empty(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
unsafe {
launch_reshape_heads_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
seq_len as i32,
num_heads as i32,
head_dim as i32,
xserv_cuda::current_stream_raw(),
);
}
out
@@ -39,36 +103,58 @@ pub fn merge_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: u
let out = Tensor::empty(&[seq_len, hidden], DType::BF16, x.device());
unsafe {
launch_merge_heads_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
seq_len as i32,
num_heads as i32,
head_dim as i32,
xserv_cuda::current_stream_raw(),
);
}
out
}
/// [1, H, S, D] → [S, H, D] for RoPE on GPU (BF16)
pub fn transpose_for_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
pub fn transpose_for_rope_gpu(
x: &Tensor,
seq_len: usize,
num_heads: usize,
head_dim: usize,
) -> Tensor {
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
let out = Tensor::empty(&[seq_len, num_heads, head_dim], DType::BF16, x.device());
unsafe {
launch_transpose_hsd_to_shd_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
seq_len as i32,
num_heads as i32,
head_dim as i32,
xserv_cuda::current_stream_raw(),
);
}
out
}
/// [S, H, D] → [1, H, S, D] after RoPE on GPU (BF16)
pub fn transpose_from_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
pub fn transpose_from_rope_gpu(
x: &Tensor,
seq_len: usize,
num_heads: usize,
head_dim: usize,
) -> Tensor {
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
let out = Tensor::empty(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
unsafe {
launch_transpose_shd_to_hsd_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
seq_len as i32,
num_heads as i32,
head_dim as i32,
xserv_cuda::current_stream_raw(),
);
}
out
@@ -76,7 +162,9 @@ pub fn transpose_from_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, hea
/// [1, KV_H, S, D] → [1, KV_H*n_rep, S, D] on GPU (BF16)
pub fn repeat_kv_gpu(x: &Tensor, n_rep: usize) -> Tensor {
if n_rep == 1 { return x.clone(); }
if n_rep == 1 {
return x.clone();
}
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
let kv_heads = x.shape()[1];
@@ -86,8 +174,13 @@ pub fn repeat_kv_gpu(x: &Tensor, n_rep: usize) -> Tensor {
let out = Tensor::empty(&[1, new_heads, seq_len, head_dim], DType::BF16, x.device());
unsafe {
launch_repeat_kv_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
kv_heads as i32, n_rep as i32, seq_len as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
x.data_ptr() as _,
out.data_ptr() as *mut c_void,
kv_heads as i32,
n_rep as i32,
seq_len as i32,
head_dim as i32,
xserv_cuda::current_stream_raw(),
);
}
out
@@ -122,20 +215,41 @@ pub fn strided_to_contiguous_gpu(x: &Tensor) -> Tensor {
unsafe {
match x.dtype() {
DType::BF16 => launch_strided_copy_bf16(
storage_ptr as _, out.data_ptr() as *mut c_void,
numel as i32, ndim as i32,
shape4[0], shape4[1], shape4[2], shape4[3],
strides4[0], strides4[1], strides4[2], strides4[3],
in_offset, xserv_cuda::current_stream_raw(),
storage_ptr as _,
out.data_ptr() as *mut c_void,
numel as i32,
ndim as i32,
shape4[0],
shape4[1],
shape4[2],
shape4[3],
strides4[0],
strides4[1],
strides4[2],
strides4[3],
in_offset,
xserv_cuda::current_stream_raw(),
),
DType::F32 => launch_strided_copy_f32(
storage_ptr as _, out.data_ptr() as *mut c_void,
numel as i32, ndim as i32,
shape4[0], shape4[1], shape4[2], shape4[3],
strides4[0], strides4[1], strides4[2], strides4[3],
in_offset, xserv_cuda::current_stream_raw(),
storage_ptr as _,
out.data_ptr() as *mut c_void,
numel as i32,
ndim as i32,
shape4[0],
shape4[1],
shape4[2],
shape4[3],
strides4[0],
strides4[1],
strides4[2],
strides4[3],
in_offset,
xserv_cuda::current_stream_raw(),
),
_ => panic!(
"strided_to_contiguous_gpu: unsupported dtype {:?}",
x.dtype()
),
_ => panic!("strided_to_contiguous_gpu: unsupported dtype {:?}", x.dtype()),
}
}
out

View File

@@ -1,11 +1,21 @@
use xserv_kernels::*;
use xserv_tensor::{Device, Tensor};
fn init() { xserv_cuda::device::set_device(0).unwrap(); }
fn init() {
xserv_cuda::device::set_device(0).unwrap();
}
fn cpu_attention(q: &[f32], k: &[f32], v: &[f32],
batch: usize, heads: usize, q_len: usize, kv_len: usize, head_dim: usize,
causal: bool) -> Vec<f32> {
fn cpu_attention(
q: &[f32],
k: &[f32],
v: &[f32],
batch: usize,
heads: usize,
q_len: usize,
kv_len: usize,
head_dim: usize,
causal: bool,
) -> Vec<f32> {
let mut out = vec![0.0f32; batch * heads * q_len * head_dim];
let scale = 1.0 / (head_dim as f32).sqrt();
@@ -70,8 +80,13 @@ fn check_close(a: &[f32], b: &[f32], atol: f32, name: &str) {
let mut max_err = 0.0f32;
for (i, (x, y)) in a.iter().zip(b).enumerate() {
let err = (x - y).abs();
if err > max_err { max_err = err; }
assert!(err <= atol, "{name}: mismatch at [{i}]: got {x}, expected {y}, err {err}");
if err > max_err {
max_err = err;
}
assert!(
err <= atol,
"{name}: mismatch at [{i}]: got {x}, expected {y}, err {err}"
);
}
println!("{name}: max_err = {max_err:.6e}");
}
@@ -105,7 +120,9 @@ fn test_batched_matmul() {
for i in 0..m {
for j in 0..n {
let mut s = 0.0f32;
for kk in 0..k { s += a_cpu[i * k + kk] * b_cpu[kk * n + j]; }
for kk in 0..k {
s += a_cpu[i * k + kk] * b_cpu[kk * n + j];
}
expected[i * n + j] = s;
}
}
@@ -116,7 +133,10 @@ fn test_batched_matmul() {
#[test]
fn test_attention_no_causal() {
init();
let b = 1; let h = 2; let s = 8; let d = 16;
let b = 1;
let h = 2;
let s = 8;
let d = 16;
let q_data = make_data(b * h * s * d);
let k_data = make_data(b * h * s * d);
let v_data = make_data(b * h * s * d);
@@ -126,13 +146,21 @@ fn test_attention_no_causal() {
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
let out = attention(&q, &k, &v, false).to_device(Device::Cpu);
check_close(out.as_slice::<f32>(), &expected, 1e-4, "attention_no_causal");
check_close(
out.as_slice::<f32>(),
&expected,
1e-4,
"attention_no_causal",
);
}
#[test]
fn test_attention_causal() {
init();
let b = 1; let h = 2; let s = 16; let d = 32;
let b = 1;
let h = 2;
let s = 16;
let d = 32;
let q_data = make_data(b * h * s * d);
let k_data = make_data(b * h * s * d);
let v_data = make_data(b * h * s * d);
@@ -148,7 +176,10 @@ fn test_attention_causal() {
#[test]
fn test_attention_causal_larger() {
init();
let b = 2; let h = 4; let s = 64; let d = 64;
let b = 2;
let h = 4;
let s = 64;
let d = 64;
let q_data = make_data(b * h * s * d);
let k_data = make_data(b * h * s * d);
let v_data = make_data(b * h * s * d);
@@ -158,18 +189,28 @@ fn test_attention_causal_larger() {
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
let out = attention(&q, &k, &v, true).to_device(Device::Cpu);
check_close(out.as_slice::<f32>(), &expected, 1e-2, "attention_causal_larger");
check_close(
out.as_slice::<f32>(),
&expected,
1e-2,
"attention_causal_larger",
);
}
#[test]
fn test_attention_causal_first_row_sees_only_first_token() {
init();
let b = 1; let h = 1; let s = 4; let d = 8;
let b = 1;
let h = 1;
let s = 4;
let d = 8;
let q_data = make_data(b * h * s * d);
let k_data = make_data(b * h * s * d);
let v_data: Vec<f32> = (0..s * d).map(|i| {
if i < d { 1.0 } else { 0.0 } // only first V row is nonzero
}).collect();
let v_data: Vec<f32> = (0..s * d)
.map(|i| {
if i < d { 1.0 } else { 0.0 } // only first V row is nonzero
})
.collect();
let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
@@ -181,7 +222,11 @@ fn test_attention_causal_first_row_sees_only_first_token() {
// output[0] should be exactly V[0] = [1, 1, 1, ...1]
let result = out.as_slice::<f32>();
for i in 0..d {
assert!((result[i] - 1.0).abs() < 1e-5,
"first row should equal V[0], got {} at dim {}", result[i], i);
assert!(
(result[i] - 1.0).abs() < 1e-5,
"first row should equal V[0], got {} at dim {}",
result[i],
i
);
}
}

View File

@@ -1,5 +1,5 @@
use half::bf16;
use xserv_kernels::{matmul, GemmBackend};
use xserv_kernels::{GemmBackend, matmul};
use xserv_tensor::{Device, Tensor};
fn cpu_matmul_f32(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
@@ -75,70 +75,110 @@ fn run_gemm_test_bf16(backend: GemmBackend, m: usize, n: usize, k: usize) {
// --- F32 tests ---
#[test]
fn test_gemm_naive_f32_small() { run_gemm_test_f32(GemmBackend::Naive, 4, 4, 4); }
fn test_gemm_naive_f32_small() {
run_gemm_test_f32(GemmBackend::Naive, 4, 4, 4);
}
#[test]
fn test_gemm_naive_f32_medium() { run_gemm_test_f32(GemmBackend::Naive, 64, 64, 64); }
fn test_gemm_naive_f32_medium() {
run_gemm_test_f32(GemmBackend::Naive, 64, 64, 64);
}
#[test]
fn test_gemm_naive_f32_rect() { run_gemm_test_f32(GemmBackend::Naive, 32, 64, 48); }
fn test_gemm_naive_f32_rect() {
run_gemm_test_f32(GemmBackend::Naive, 32, 64, 48);
}
#[test]
fn test_gemm_tiled_f32_small() { run_gemm_test_f32(GemmBackend::Tiled, 4, 4, 4); }
fn test_gemm_tiled_f32_small() {
run_gemm_test_f32(GemmBackend::Tiled, 4, 4, 4);
}
#[test]
fn test_gemm_tiled_f32_medium() { run_gemm_test_f32(GemmBackend::Tiled, 128, 128, 128); }
fn test_gemm_tiled_f32_medium() {
run_gemm_test_f32(GemmBackend::Tiled, 128, 128, 128);
}
#[test]
fn test_gemm_tiled_f32_rect() { run_gemm_test_f32(GemmBackend::Tiled, 65, 33, 97); }
fn test_gemm_tiled_f32_rect() {
run_gemm_test_f32(GemmBackend::Tiled, 65, 33, 97);
}
#[test]
fn test_gemm_cublas_f32_small() { run_gemm_test_f32(GemmBackend::CuBlas, 4, 4, 4); }
fn test_gemm_cublas_f32_small() {
run_gemm_test_f32(GemmBackend::CuBlas, 4, 4, 4);
}
#[test]
fn test_gemm_cublas_f32_medium() { run_gemm_test_f32(GemmBackend::CuBlas, 256, 256, 256); }
fn test_gemm_cublas_f32_medium() {
run_gemm_test_f32(GemmBackend::CuBlas, 256, 256, 256);
}
#[test]
fn test_gemm_cublas_f32_rect() { run_gemm_test_f32(GemmBackend::CuBlas, 65, 33, 97); }
fn test_gemm_cublas_f32_rect() {
run_gemm_test_f32(GemmBackend::CuBlas, 65, 33, 97);
}
// --- BF16 tests ---
#[test]
fn test_gemm_naive_bf16_small() { run_gemm_test_bf16(GemmBackend::Naive, 4, 4, 4); }
fn test_gemm_naive_bf16_small() {
run_gemm_test_bf16(GemmBackend::Naive, 4, 4, 4);
}
#[test]
fn test_gemm_naive_bf16_medium() { run_gemm_test_bf16(GemmBackend::Naive, 64, 64, 64); }
fn test_gemm_naive_bf16_medium() {
run_gemm_test_bf16(GemmBackend::Naive, 64, 64, 64);
}
#[test]
fn test_gemm_tiled_bf16_small() { run_gemm_test_bf16(GemmBackend::Tiled, 4, 4, 4); }
fn test_gemm_tiled_bf16_small() {
run_gemm_test_bf16(GemmBackend::Tiled, 4, 4, 4);
}
#[test]
fn test_gemm_tiled_bf16_medium() { run_gemm_test_bf16(GemmBackend::Tiled, 128, 128, 128); }
fn test_gemm_tiled_bf16_medium() {
run_gemm_test_bf16(GemmBackend::Tiled, 128, 128, 128);
}
#[test]
fn test_gemm_cublas_bf16_small() { run_gemm_test_bf16(GemmBackend::CuBlas, 4, 4, 4); }
fn test_gemm_cublas_bf16_small() {
run_gemm_test_bf16(GemmBackend::CuBlas, 4, 4, 4);
}
#[test]
fn test_gemm_cublas_bf16_medium() { run_gemm_test_bf16(GemmBackend::CuBlas, 256, 256, 256); }
fn test_gemm_cublas_bf16_medium() {
run_gemm_test_bf16(GemmBackend::CuBlas, 256, 256, 256);
}
// --- Custom GEMV tests (M=1, BF16 fast path) ---
#[test]
fn test_gemv_bf16_small() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 64, 64); }
fn test_gemv_bf16_small() {
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 64, 64);
}
#[test]
fn test_gemv_bf16_medium() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 256, 256); }
fn test_gemv_bf16_medium() {
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 256, 256);
}
#[test]
fn test_gemv_bf16_4096() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 4096, 4096); }
fn test_gemv_bf16_4096() {
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 4096, 4096);
}
#[test]
fn test_gemv_bf16_rect() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 512, 4096); }
fn test_gemv_bf16_rect() {
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 512, 4096);
}
// --- Larger benchmark-style tests ---
#[test]
fn test_gemm_cublas_f32_1024() { run_gemm_test_f32(GemmBackend::CuBlas, 1024, 1024, 1024); }
fn test_gemm_cublas_f32_1024() {
run_gemm_test_f32(GemmBackend::CuBlas, 1024, 1024, 1024);
}
#[test]
fn test_gemm_consistency_all_backends() {

View File

@@ -2,7 +2,9 @@ use half::bf16;
use xserv_kernels::*;
use xserv_tensor::{Device, Tensor};
fn init() { xserv_cuda::device::set_device(0).unwrap(); }
fn init() {
xserv_cuda::device::set_device(0).unwrap();
}
// --- CPU reference implementations ---
@@ -37,10 +39,12 @@ fn cpu_layernorm(x: &[f32], gamma: &[f32], beta: &[f32], eps: f32, hidden: usize
fn cpu_gelu(x: &[f32]) -> Vec<f32> {
let sqrt_2_over_pi = 0.7978845608f32;
x.iter().map(|&v| {
let inner = sqrt_2_over_pi * (v + 0.044715 * v * v * v);
0.5 * v * (1.0 + inner.tanh())
}).collect()
x.iter()
.map(|&v| {
let inner = sqrt_2_over_pi * (v + 0.044715 * v * v * v);
0.5 * v * (1.0 + inner.tanh())
})
.collect()
}
fn cpu_silu(x: &[f32]) -> Vec<f32> {
@@ -88,8 +92,13 @@ fn check_close(result: &[f32], expected: &[f32], atol: f32, name: &str) {
let mut max_err = 0.0f32;
for (i, (r, e)) in result.iter().zip(expected).enumerate() {
let err = (r - e).abs();
if err > max_err { max_err = err; }
assert!(err <= atol, "{name}: mismatch at [{i}]: got {r}, expected {e}, err {err}");
if err > max_err {
max_err = err;
}
assert!(
err <= atol,
"{name}: mismatch at [{i}]: got {r}, expected {e}, err {err}"
);
}
println!("{name}: max_err = {max_err:.6e}");
}
@@ -208,13 +217,18 @@ fn test_softmax_sum_to_one() {
init();
let rows = 4;
let cols = 2048;
let data: Vec<f32> = (0..rows * cols).map(|i| ((i % 31) as f32 - 15.0) * 0.5).collect();
let data: Vec<f32> = (0..rows * cols)
.map(|i| ((i % 31) as f32 - 15.0) * 0.5)
.collect();
let x = Tensor::from_slice(&data, &[rows, cols]).to_device(Device::Cuda(0));
let out = softmax(&x).to_device(Device::Cpu);
let result = out.as_slice::<f32>();
for r in 0..rows {
let row_sum: f32 = result[r * cols..(r + 1) * cols].iter().sum();
assert!((row_sum - 1.0).abs() < 1e-5, "softmax row {r} sum = {row_sum}");
assert!(
(row_sum - 1.0).abs() < 1e-5,
"softmax row {r} sum = {row_sum}"
);
}
}
@@ -247,8 +261,10 @@ fn test_embedding_f32() {
for i in 0..hidden {
let expected = table_data[tid as usize * hidden + i];
let got = result[seq_idx * hidden + i];
assert!((got - expected).abs() < 1e-6,
"embedding mismatch at [{seq_idx},{i}]: got {got}, expected {expected}");
assert!(
(got - expected).abs() < 1e-6,
"embedding mismatch at [{seq_idx},{i}]: got {got}, expected {expected}"
);
}
}
}
@@ -270,8 +286,8 @@ fn test_rope_f32() {
let mut expected = x_data.clone();
cpu_rope(&mut expected, &positions, num_heads, head_dim, theta);
let x = Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim])
.to_device(Device::Cuda(0));
let x =
Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim]).to_device(Device::Cuda(0));
let cache = RopeCache::new(64, head_dim, theta);
rope_inplace(&x, &cache, &positions);
@@ -292,8 +308,8 @@ fn test_rope_position_0_identity() {
.map(|i| (i as f32 + 1.0) * 0.1)
.collect();
let x = Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim])
.to_device(Device::Cuda(0));
let x =
Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim]).to_device(Device::Cuda(0));
let cache = RopeCache::new(64, head_dim, 10000.0);
rope_inplace(&x, &cache, &positions);

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use std::time::Instant;
use xserv_distributed::{TpContext, UniqueId, get_unique_id};
use xserv_model::{loader, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, BLOCK_SIZE};
use xserv_model::{BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, loader};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
@@ -23,8 +23,12 @@ fn main() {
eprintln!(
"gpt-oss-20b: layers={}, hidden={}, heads={}/{} kv, experts={}, top_k={}, vocab={}",
config.num_layers(), config.hidden(), config.num_heads(),
config.num_kv_heads(), config.num_experts(), config.experts_per_token(),
config.num_layers(),
config.hidden(),
config.num_heads(),
config.num_kv_heads(),
config.num_experts(),
config.experts_per_token(),
config.vocab_size
);
eprintln!("TP world={world}, max_tokens={max_tokens}");
@@ -59,17 +63,29 @@ fn main() {
let tp0 = Arc::new(TpContext::init(0, world, uid, 0));
eprintln!("[rank 0] Loading weights...");
let weights = loader::load_model_dir(&model_dir, Device::Cpu);
eprintln!("[rank 0] Loaded {} tensors, building model...", weights.len());
eprintln!(
"[rank 0] Loaded {} tensors, building model...",
weights.len()
);
let model = GptOss::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp0));
let total_blocks = max_blocks_per_seq + 64;
let mut cache = PagedKVCache::new_tp(
&config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, 0,
&config,
local_kv,
total_blocks,
0,
4,
max_blocks_per_seq,
DType::BF16,
0,
);
eprintln!("[rank 0] Ready.");
// Prompt
let prompt_arg = get_arg::<String>(&args, "--prompt");
let prompt = prompt_arg.as_deref().unwrap_or("What is the meaning of life?");
let prompt = prompt_arg
.as_deref()
.unwrap_or("What is the meaning of life?");
let token_ids = tokenizer.encode(prompt);
eprintln!("Prompt ({} tokens): {prompt}", token_ids.len());
@@ -83,11 +99,21 @@ fn main() {
// (oracle) next token. Removes free-running compounding so it isolates
// whether per-position logits agree with the llama.cpp trajectory.
if let Some(forced) = get_arg::<String>(&args, "--forced") {
let forced_ids: Vec<u32> = forced.split(',').filter_map(|s| s.trim().parse().ok()).collect();
let forced_ids: Vec<u32> = forced
.split(',')
.filter_map(|s| s.trim().parse().ok())
.collect();
let mut seq = token_ids.clone();
seq.extend_from_slice(&forced_ids);
// Workers must run the same prefill in lockstep (TP AllReduces match up).
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { tokens: seq.clone(), slot });
broadcast_cmd(
&worker_txs,
&worker_handles,
WorkerCmd::Prefill {
tokens: seq.clone(),
slot,
},
);
let logits = model.forward_prefill_paged(&seq, slot, &mut cache);
wait_workers(&worker_handles);
let logits_cpu = logits.to_device(Device::Cpu);
@@ -99,19 +125,31 @@ fn main() {
// position i predicts seq[i+1]; we check the forced region
for i in (plen - 1)..(seq.len() - 1) {
let row = &data[i * vocab..(i + 1) * vocab];
let argmax = row.iter().enumerate()
let argmax = row
.iter()
.enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(j, _)| j as u32).unwrap();
.map(|(j, _)| j as u32)
.unwrap();
let expected = seq[i + 1];
let ok = argmax == expected;
if ok { matches += 1; }
if ok {
matches += 1;
}
total += 1;
eprintln!("pos {i}: xserv_argmax={argmax} oracle={expected} {}", if ok {"OK"} else {"DIFF"});
eprintln!(
"pos {i}: xserv_argmax={argmax} oracle={expected} {}",
if ok { "OK" } else { "DIFF" }
);
}
eprintln!("\nTeacher-forced top-1 agreement: {matches}/{total} = {:.1}%",
100.0 * matches as f64 / total as f64);
eprintln!(
"\nTeacher-forced top-1 agreement: {matches}/{total} = {:.1}%",
100.0 * matches as f64 / total as f64
);
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown);
for (h, _) in worker_handles { h.join().unwrap(); }
for (h, _) in worker_handles {
h.join().unwrap();
}
return;
}
@@ -120,8 +158,18 @@ fn main() {
// per-position top-1 agreement bucketed by position. Localizes long-context
// decode degradation (which prefill teacher-forcing cannot see).
if let Some(forced) = get_arg::<String>(&args, "--forced-decode") {
let forced_ids: Vec<u32> = forced.split(',').filter_map(|s| s.trim().parse().ok()).collect();
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { tokens: token_ids.clone(), slot });
let forced_ids: Vec<u32> = forced
.split(',')
.filter_map(|s| s.trim().parse().ok())
.collect();
broadcast_cmd(
&worker_txs,
&worker_handles,
WorkerCmd::Prefill {
tokens: token_ids.clone(),
slot,
},
);
let logits = model.forward_prefill_paged(&token_ids, slot, &mut cache);
wait_workers(&worker_handles);
let mut pred = sample_greedy_last(&logits); // prediction for forced[0]
@@ -133,34 +181,55 @@ fn main() {
matches += ok as usize;
total += 1;
let b = i / bucket;
if buckets.len() <= b { buckets.push((0, 0)); }
if buckets.len() <= b {
buckets.push((0, 0));
}
buckets[b].0 += ok as usize;
buckets[b].1 += 1;
// Teacher-force: feed the oracle token through the decode path.
let pos = cache.seq_len(slot);
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Decode {
tokens: vec![f], positions: vec![pos], slots: vec![slot],
});
broadcast_cmd(
&worker_txs,
&worker_handles,
WorkerCmd::Decode {
tokens: vec![f],
positions: vec![pos],
slots: vec![slot],
},
);
let logits = model.forward_decode_paged(&[f], &[pos], &[slot], &mut cache);
wait_workers(&worker_handles);
pred = sample_greedy_last(&logits);
}
eprintln!("Teacher-forced DECODE agreement: {matches}/{total} = {:.1}%",
100.0 * matches as f64 / total as f64);
eprintln!(
"Teacher-forced DECODE agreement: {matches}/{total} = {:.1}%",
100.0 * matches as f64 / total as f64
);
for (b, (m, t)) in buckets.iter().enumerate() {
eprintln!(" pos[{:>4}..{:<4}]: {m:>3}/{t:<3} = {:.0}%",
b * bucket, b * bucket + t, 100.0 * (*m as f64) / (*t as f64));
eprintln!(
" pos[{:>4}..{:<4}]: {m:>3}/{t:<3} = {:.0}%",
b * bucket,
b * bucket + t,
100.0 * (*m as f64) / (*t as f64)
);
}
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown);
for (h, _) in worker_handles { h.join().unwrap(); }
for (h, _) in worker_handles {
h.join().unwrap();
}
return;
}
// Prefill
let t0 = Instant::now();
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill {
tokens: token_ids.clone(), slot,
});
broadcast_cmd(
&worker_txs,
&worker_handles,
WorkerCmd::Prefill {
tokens: token_ids.clone(),
slot,
},
);
let logits = model.forward_prefill_paged(&token_ids, slot, &mut cache);
wait_workers(&worker_handles);
let ttft = t0.elapsed();
@@ -178,12 +247,20 @@ fn main() {
let text = tokenizer.decode(&[next]);
print!("{text}");
if tokenizer.eos_token_id() == Some(next) { break; }
if tokenizer.eos_token_id() == Some(next) {
break;
}
let pos = cache.seq_len(slot);
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Decode {
tokens: vec![next], positions: vec![pos], slots: vec![slot],
});
broadcast_cmd(
&worker_txs,
&worker_handles,
WorkerCmd::Decode {
tokens: vec![next],
positions: vec![pos],
slots: vec![slot],
},
);
let logits = decoder.decode(&model, &[next], &[pos], &[slot], &mut cache);
wait_workers(&worker_handles);
@@ -196,13 +273,20 @@ fn main() {
let gen_tokens = output_tokens.len();
let full_text = tokenizer.decode(&output_tokens);
eprintln!("\nGenerated text: {full_text}");
eprintln!("Token IDs: {:?}", &output_tokens[..output_tokens.len().min(20)]);
eprintln!(
"Token IDs: {:?}",
&output_tokens[..output_tokens.len().min(20)]
);
let tpot = if gen_tokens > 1 {
decode_elapsed.as_secs_f64() * 1000.0 / (gen_tokens - 1) as f64
} else { 0.0 };
} else {
0.0
};
let tok_s = if gen_tokens > 1 {
(gen_tokens - 1) as f64 / decode_elapsed.as_secs_f64()
} else { 0.0 };
} else {
0.0
};
eprintln!("\n--- Performance ---");
eprintln!("Generated: {} tokens", gen_tokens);
@@ -222,8 +306,15 @@ fn main() {
#[derive(Clone)]
enum WorkerCmd {
Register(usize),
Prefill { tokens: Vec<u32>, slot: usize },
Decode { tokens: Vec<u32>, positions: Vec<usize>, slots: Vec<usize> },
Prefill {
tokens: Vec<u32>,
slot: usize,
},
Decode {
tokens: Vec<u32>,
positions: Vec<usize>,
slots: Vec<usize>,
},
Shutdown,
}
@@ -241,12 +332,20 @@ fn worker_loop(
let tp = Arc::new(TpContext::init(rank, world, uid, rank as u32));
eprintln!("[rank {rank}] Loading weights...");
let weights = loader::load_model_dir(&model_dir, Device::Cpu);
let model = GptOss::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp));
let model =
GptOss::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp));
let local_kv = config.num_kv_heads() / world;
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
let total_blocks = max_blocks_per_seq + 64;
let mut cache = PagedKVCache::new_tp(
&config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, rank as u32,
&config,
local_kv,
total_blocks,
0,
4,
max_blocks_per_seq,
DType::BF16,
rank as u32,
);
eprintln!("[rank {rank}] Ready.");
ack_tx.send(()).unwrap();
@@ -260,7 +359,11 @@ fn worker_loop(
WorkerCmd::Prefill { tokens, slot } => {
let _ = model.forward_prefill_paged(&tokens, slot, &mut cache);
}
WorkerCmd::Decode { tokens, positions, slots } => {
WorkerCmd::Decode {
tokens,
positions,
slots,
} => {
let _ = decoder.decode(&model, &tokens, &positions, &slots, &mut cache);
}
WorkerCmd::Shutdown => break,
@@ -299,14 +402,15 @@ fn sample_greedy_last(logits: &xserv_tensor::Tensor) -> u32 {
let data = logits_cpu.as_slice::<bf16>();
let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
last.iter().enumerate()
last.iter()
.enumerate()
.max_by(|a, b| {
let af = a.1.to_f32();
let bf = b.1.to_f32();
af.partial_cmp(&bf).unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(i, _)| i as u32).unwrap()
.map(|(i, _)| i as u32)
.unwrap()
}
fn get_arg<T: std::str::FromStr>(args: &[String], flag: &str) -> Option<T> {

View File

@@ -1,7 +1,7 @@
use std::path::PathBuf;
use std::time::Instant;
use xserv_model::gpt2::{sample_greedy, KVCache};
use xserv_model::{loader, GPT2, ModelConfig};
use xserv_model::gpt2::{KVCache, sample_greedy};
use xserv_model::{GPT2, ModelConfig, loader};
use xserv_tensor::Device;
use xserv_tokenizer::Tokenizer;
@@ -104,9 +104,15 @@ fn main() {
let tbt_us = if !token_times_us.is_empty() {
token_times_us.iter().sum::<u128>() / token_times_us.len() as u128
} else { 0 };
} else {
0
};
let total_gen_us: u128 = ttft_us + token_times_us.iter().sum::<u128>();
let tpot_us = if num_generated > 0 { total_gen_us / num_generated as u128 } else { 0 };
let tpot_us = if num_generated > 0 {
total_gen_us / num_generated as u128
} else {
0
};
let gen_text_escaped = generated_text
.replace('\\', "\\\\")
@@ -124,11 +130,16 @@ fn main() {
print!("\"ttft_us\": {ttft_us}, ");
print!("\"tbt_us\": {tbt_us}, ");
print!("\"tpot_us\": {tpot_us}}}");
if i < prompts.len() - 1 { println!(","); } else { println!(); }
if i < prompts.len() - 1 {
println!(",");
} else {
println!();
}
eprintln!(
"[{}/{}] input={input_len}tok gen={num_generated}tok ttft={:.1}ms tbt={:.1}ms | {}",
i + 1, prompts.len(),
i + 1,
prompts.len(),
ttft_us as f64 / 1000.0,
tbt_us as f64 / 1000.0,
&generated_text.replace('\n', " ")[..generated_text.len().min(60)]
@@ -138,12 +149,18 @@ fn main() {
}
fn generate_with_cache(
model: &GPT2, config: &ModelConfig, tokenizer: &Tokenizer,
input_ids: &[u32], gen_tokens: usize,
model: &GPT2,
config: &ModelConfig,
tokenizer: &Tokenizer,
input_ids: &[u32],
gen_tokens: usize,
) -> (Vec<u32>, u128, Vec<u128>) {
let mut cache = KVCache::new(
config.num_layers(), config.num_heads(), config.head_dim(),
xserv_tensor::DType::F32, Device::Cuda(0),
config.num_layers(),
config.num_heads(),
config.head_dim(),
xserv_tensor::DType::F32,
Device::Cuda(0),
);
// Prefill
@@ -163,15 +180,19 @@ fn generate_with_cache(
let next = sample_greedy(&logits);
token_times.push(t_start.elapsed().as_micros());
generated.push(next);
if tokenizer.eos_token_id() == Some(next) { break; }
if tokenizer.eos_token_id() == Some(next) {
break;
}
}
(generated, ttft_us, token_times)
}
fn generate_no_cache(
model: &GPT2, tokenizer: &Tokenizer,
input_ids: &[u32], gen_tokens: usize,
model: &GPT2,
tokenizer: &Tokenizer,
input_ids: &[u32],
gen_tokens: usize,
) -> (Vec<u32>, u128, Vec<u128>) {
let mut all_ids = input_ids.to_vec();
@@ -191,7 +212,9 @@ fn generate_no_cache(
token_times.push(t_start.elapsed().as_micros());
all_ids.push(next);
generated.push(next);
if tokenizer.eos_token_id() == Some(next) { break; }
if tokenizer.eos_token_id() == Some(next) {
break;
}
}
(generated, ttft_us, token_times)

View File

@@ -1,7 +1,7 @@
use std::path::PathBuf;
use std::time::Instant;
use xserv_model::qwen3::sample_greedy;
use xserv_model::{loader, DecodeGraphState, GpuKVCache, ModelConfig, Qwen3};
use xserv_model::{DecodeGraphState, GpuKVCache, ModelConfig, Qwen3, loader};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
@@ -139,18 +139,35 @@ fn main() {
} else {
// Replay captured graphs
let pos = cache.seq_len() as u32;
graph.execute(last, pos, &mut cache, &layer_ptrs, embed, config.vocab_size as i32, config.hidden() as i32);
graph.execute(
last,
pos,
&mut cache,
&layer_ptrs,
embed,
config.vocab_size as i32,
config.hidden() as i32,
);
cache.advance_seq_len(1);
// Read logits from graph buffer
let vocab_size = config.vocab_size;
let mut logits_bytes = vec![0u8; vocab_size * 2];
graph.logits_buffer().copy_to_host(&mut logits_bytes).unwrap();
graph
.logits_buffer()
.copy_to_host(&mut logits_bytes)
.unwrap();
let logits_data: &[half::bf16] = unsafe {
std::slice::from_raw_parts(logits_bytes.as_ptr() as *const half::bf16, vocab_size)
std::slice::from_raw_parts(
logits_bytes.as_ptr() as *const half::bf16,
vocab_size,
)
};
logits_data.iter().enumerate()
logits_data
.iter()
.enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(idx, _)| idx as u32).unwrap()
.map(|(idx, _)| idx as u32)
.unwrap()
}
} else {
let logits = model.forward_gpu_cache(&[last], &mut cache);
@@ -159,16 +176,24 @@ fn main() {
token_times.push(t_start.elapsed().as_micros());
generated.push(next);
if tokenizer.eos_token_id() == Some(next) { break; }
if tokenizer.eos_token_id() == Some(next) {
break;
}
}
let num_generated = generated.len();
let generated_text = tokenizer.decode(&generated);
let tbt_us = if !token_times.is_empty() {
token_times.iter().sum::<u128>() / token_times.len() as u128
} else { 0 };
} else {
0
};
let total_gen_us: u128 = ttft_us + token_times.iter().sum::<u128>();
let tpot_us = if num_generated > 0 { total_gen_us / num_generated as u128 } else { 0 };
let tpot_us = if num_generated > 0 {
total_gen_us / num_generated as u128
} else {
0
};
let gen_text_escaped = generated_text
.replace('\\', "\\\\")
@@ -186,13 +211,18 @@ fn main() {
print!("\"ttft_us\": {ttft_us}, ");
print!("\"tbt_us\": {tbt_us}, ");
print!("\"tpot_us\": {tpot_us}}}");
if i < prompts.len() - 1 { println!(","); } else { println!(); }
if i < prompts.len() - 1 {
println!(",");
} else {
println!();
}
let display_text = generated_text.replace('\n', " ");
let truncated: String = display_text.chars().take(60).collect();
eprintln!(
"[{}/{}] input={input_len}tok gen={num_generated}tok ttft={:.1}ms tbt={:.1}ms | {}",
i + 1, prompts.len(),
i + 1,
prompts.len(),
ttft_us as f64 / 1000.0,
tbt_us as f64 / 1000.0,
truncated

View File

@@ -18,7 +18,7 @@ use std::thread;
use std::time::Instant;
use xserv_model::qwen3::sample_greedy;
use xserv_model::{loader, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
@@ -35,8 +35,13 @@ fn main() {
std::process::exit(1);
}
let model_dir = PathBuf::from(&args[1]);
let world: usize = arg(&args, "--tp").and_then(|s| s.parse().ok()).unwrap_or(1).max(1);
let gen_tokens: usize = arg(&args, "--gen-tokens").and_then(|s| s.parse().ok()).unwrap_or(64);
let world: usize = arg(&args, "--tp")
.and_then(|s| s.parse().ok())
.unwrap_or(1)
.max(1);
let gen_tokens: usize = arg(&args, "--gen-tokens")
.and_then(|s| s.parse().ok())
.unwrap_or(64);
let devices: Vec<u32> = match arg(&args, "--devices") {
Some(s) => s.split(',').filter_map(|d| d.trim().parse().ok()).collect(),
None => (0..world as u32).collect(),
@@ -67,7 +72,11 @@ fn main() {
// Tensors are not Send (their Storage holds a raw GPU pointer), so each rank
// thread loads its own CPU copy of the weights and shards in-thread. Loading
// is not part of the timed region.
let id = if world > 1 { Some(xserv_distributed::get_unique_id()) } else { None };
let id = if world > 1 {
Some(xserv_distributed::get_unique_id())
} else {
None
};
let handles: Vec<_> = (0..world)
.map(|rank| {
@@ -76,7 +85,9 @@ fn main() {
let prompt_ids = prompt_ids.clone();
let device = devices[rank];
thread::spawn(move || {
run_rank(rank, world, device, id, config, model_dir, prompt_ids, gen_tokens, eos)
run_rank(
rank, world, device, id, config, model_dir, prompt_ids, gen_tokens, eos,
)
})
})
.collect();
@@ -91,7 +102,10 @@ fn main() {
let results = rank0.expect("rank 0 produced no results");
println!("\n=== TP={world} (devices {devices:?}) — Qwen3 E2E benchmark ===");
println!("{:<45} {:>10} {:>12} {:>8}", "prompt", "TTFT(ms)", "decode tok/s", "gen");
println!(
"{:<45} {:>10} {:>12} {:>8}",
"prompt", "TTFT(ms)", "decode tok/s", "gen"
);
let mut tps_sum = 0.0;
for (i, r) in results.iter().enumerate() {
let text = tokenizer.decode(&r.gen_ids).replace('\n', " ");
@@ -99,16 +113,29 @@ fn main() {
let p: String = prompts[i].chars().take(43).collect();
println!(
"{:<45} {:>10.1} {:>12.1} {:>8} | {}",
p, r.ttft_ms, r.decode_tok_s, r.gen_ids.len(), short
p,
r.ttft_ms,
r.decode_tok_s,
r.gen_ids.len(),
short
);
tps_sum += r.decode_tok_s;
}
println!("--- mean decode throughput: {:.1} tok/s ---", tps_sum / results.len() as f64);
println!(
"--- mean decode throughput: {:.1} tok/s ---",
tps_sum / results.len() as f64
);
// Machine-readable line for cross-TP correctness diffing (rank 0 token ids).
let all_ids: Vec<String> = results
.iter()
.map(|r| r.gen_ids.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(","))
.map(|r| {
r.gen_ids
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(",")
})
.collect();
println!("CORRECTNESS_IDS tp={world} {}", all_ids.join(" | "));
}
@@ -126,7 +153,12 @@ fn run_rank(
) -> Option<Vec<PromptResult>> {
// Bind this thread to its GPU and set up the TP communicator.
let tp = if world > 1 {
Some(Arc::new(xserv_distributed::TpContext::init(rank, world, id.unwrap(), device)))
Some(Arc::new(xserv_distributed::TpContext::init(
rank,
world,
id.unwrap(),
device,
)))
} else {
xserv_cuda::device::set_device(device).unwrap();
None
@@ -142,7 +174,14 @@ fn run_rank(
let max_blocks_per_seq = max_seq.div_ceil(BLOCK_SIZE);
let total_blocks = max_blocks_per_seq + 8;
let mut cache = PagedKVCache::new_tp(
&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, device,
&config,
local_kv,
total_blocks,
0,
1,
max_blocks_per_seq,
DType::BF16,
device,
);
// Warmup (init kernels / allocator / NCCL channels) — not timed.
@@ -177,12 +216,20 @@ fn run_rank(
steps += 1;
}
let decode_s = t1.elapsed().as_secs_f64();
let decode_tok_s = if steps > 0 && decode_s > 0.0 { steps as f64 / decode_s } else { 0.0 };
let decode_tok_s = if steps > 0 && decode_s > 0.0 {
steps as f64 / decode_s
} else {
0.0
};
cache.free_sequence(0);
if rank == 0 {
out.push(PromptResult { gen_ids: generated, ttft_ms, decode_tok_s });
out.push(PromptResult {
gen_ids: generated,
ttft_ms,
decode_tok_s,
});
}
}
@@ -190,5 +237,8 @@ fn run_rank(
}
fn arg<'a>(args: &'a [String], flag: &str) -> Option<&'a str> {
args.iter().position(|a| a == flag).and_then(|i| args.get(i + 1)).map(|s| s.as_str())
args.iter()
.position(|a| a == flag)
.and_then(|i| args.get(i + 1))
.map(|s| s.as_str())
}

View File

@@ -1,8 +1,8 @@
use half::bf16;
use std::path::PathBuf;
use xserv_model::{loader, KVCache, ModelConfig, Qwen3};
use xserv_model::{KVCache, ModelConfig, Qwen3, loader};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
use half::bf16;
fn main() {
let args: Vec<String> = std::env::args().collect();
@@ -20,8 +20,11 @@ fn main() {
eprintln!("Token IDs: {token_ids:?}");
let mut cache = KVCache::new(
config.num_layers(), config.num_kv_heads(), config.head_dim(),
DType::BF16, Device::Cuda(0),
config.num_layers(),
config.num_kv_heads(),
config.head_dim(),
DType::BF16,
Device::Cuda(0),
);
let logits = model.forward_with_cache(&token_ids, &mut cache);
let logits_cpu = logits.to_device(Device::Cpu);
@@ -31,7 +34,9 @@ fn main() {
// Print top-20 logits for the last position
let last_row = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
let mut indexed: Vec<(usize, f32)> = last_row.iter().enumerate()
let mut indexed: Vec<(usize, f32)> = last_row
.iter()
.enumerate()
.map(|(i, v)| (i, v.to_f32()))
.collect();
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

View File

@@ -1,10 +1,13 @@
use std::io::{self, IsTerminal, Read, Write};
use std::path::PathBuf;
use std::sync::{mpsc, Arc};
use std::sync::{Arc, mpsc};
use std::thread;
use xserv_model::{GraphedGptOssDecoder, loader, sample, sample_greedy_penalized, GptOss, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE};
use xserv_model::{
BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, SamplingParams,
loader, sample, sample_greedy_penalized,
};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
@@ -14,13 +17,24 @@ enum ChatModel {
}
impl ChatModel {
fn forward_prefill_paged(&self, tokens: &[u32], slot: usize, cache: &mut PagedKVCache) -> xserv_tensor::Tensor {
fn forward_prefill_paged(
&self,
tokens: &[u32],
slot: usize,
cache: &mut PagedKVCache,
) -> xserv_tensor::Tensor {
match self {
ChatModel::Qwen3(m) => m.forward_prefill_paged(tokens, slot, cache),
ChatModel::GptOss(m) => m.forward_prefill_paged(tokens, slot, cache),
}
}
fn forward_decode_paged(&self, tokens: &[u32], positions: &[usize], slots: &[usize], cache: &mut PagedKVCache) -> xserv_tensor::Tensor {
fn forward_decode_paged(
&self,
tokens: &[u32],
positions: &[usize],
slots: &[usize],
cache: &mut PagedKVCache,
) -> xserv_tensor::Tensor {
match self {
ChatModel::Qwen3(m) => m.forward_decode_paged(tokens, positions, slots, cache),
ChatModel::GptOss(m) => m.forward_decode_paged(tokens, positions, slots, cache),
@@ -33,8 +47,15 @@ impl ChatModel {
enum TpCommand {
Register(usize),
Free(usize),
Prefill { tokens: Vec<u32>, slot: usize },
Decode { tokens: Vec<u32>, positions: Vec<usize>, slots: Vec<usize> },
Prefill {
tokens: Vec<u32>,
slot: usize,
},
Decode {
tokens: Vec<u32>,
positions: Vec<usize>,
slots: Vec<usize>,
},
}
struct TpHandle {
@@ -56,7 +77,8 @@ impl TpHandle {
}
fn tp_worker_loop(
rank: usize, world: usize,
rank: usize,
world: usize,
id: xserv_distributed::UniqueId,
model_dir: std::path::PathBuf,
config: ModelConfig,
@@ -64,29 +86,68 @@ fn tp_worker_loop(
cmd_rx: mpsc::Receiver<TpCommand>,
ack_tx: mpsc::Sender<()>,
) {
let tp = Arc::new(xserv_distributed::TpContext::init(rank, world, id, rank as u32));
let tp = Arc::new(xserv_distributed::TpContext::init(
rank,
world,
id,
rank as u32,
));
let weights = loader::load_model_dir(&model_dir, Device::Cpu);
let model = if config.is_moe() {
ChatModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp)))
ChatModel::GptOss(GptOss::from_weights_tp(
config.clone(),
weights,
rank,
world,
rank as u32,
Some(tp),
))
} else {
ChatModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp)))
ChatModel::Qwen3(Qwen3::from_weights_tp(
config.clone(),
weights,
rank,
world,
rank as u32,
Some(tp),
))
};
let local_kv = config.num_kv_heads() / world;
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
let total_blocks = max_blocks_per_seq + 8;
let mut cache = PagedKVCache::new_tp(
&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, rank as u32,
&config,
local_kv,
total_blocks,
0,
1,
max_blocks_per_seq,
DType::BF16,
rank as u32,
);
let mut decoder = GraphedGptOssDecoder::new();
while let Ok(cmd) = cmd_rx.recv() {
match cmd {
TpCommand::Register(slot) => { let _ = cache.register_sequence(slot); }
TpCommand::Register(slot) => {
let _ = cache.register_sequence(slot);
}
TpCommand::Free(slot) => cache.free_sequence(slot),
TpCommand::Prefill { tokens, slot } => {
let _ = model.forward_prefill_paged(&tokens, slot, &mut cache);
}
TpCommand::Decode { tokens, positions, slots } => {
let _ = chat_decode(&model, &mut decoder, &tokens, &positions, &slots, &mut cache);
TpCommand::Decode {
tokens,
positions,
slots,
} => {
let _ = chat_decode(
&model,
&mut decoder,
&tokens,
&positions,
&slots,
&mut cache,
);
}
}
let _ = ack_tx.send(());
@@ -221,7 +282,13 @@ fn read_line_edited(prompt: &str) -> Line {
}
b => {
// UTF-8 multi-byte: read the continuation bytes for this char.
let extra = if b >= 0xF0 { 3 } else if b >= 0xE0 { 2 } else { 1 };
let extra = if b >= 0xF0 {
3
} else if b >= 0xE0 {
2
} else {
1
};
let mut bytes = vec![b];
let mut cont = [0u8; 1];
let mut ok = true;
@@ -275,7 +342,8 @@ fn main() {
if world > 1 {
assert!(
config.num_kv_heads() % world == 0,
"num_kv_heads {} not divisible by tp {world}", config.num_kv_heads()
"num_kv_heads {} not divisible by tp {world}",
config.num_kv_heads()
);
}
@@ -290,7 +358,16 @@ fn main() {
let model_dir = opts.model_dir.clone();
let config = config.clone();
thread::spawn(move || {
tp_worker_loop(rank, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx);
tp_worker_loop(
rank,
world,
id,
model_dir,
config,
max_seq_len,
ctx_rx,
ack_tx,
);
});
}
eprintln!("Loading weights (tp={world})...");
@@ -298,14 +375,37 @@ fn main() {
let weights = loader::load_model_dir(&opts.model_dir, Device::Cpu);
eprintln!("Loaded {} tensors", weights.len());
let m = if is_moe {
ChatModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp)))
ChatModel::GptOss(GptOss::from_weights_tp(
config.clone(),
weights,
0,
world,
0,
Some(tp),
))
} else {
ChatModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp)))
ChatModel::Qwen3(Qwen3::from_weights_tp(
config.clone(),
weights,
0,
world,
0,
Some(tp),
))
};
let local_kv = config.num_kv_heads() / world;
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
let total_blocks = max_blocks_per_seq + 8;
let c = PagedKVCache::new_tp(&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, 0);
let c = PagedKVCache::new_tp(
&config,
local_kv,
total_blocks,
0,
1,
max_blocks_per_seq,
DType::BF16,
0,
);
let h = TpHandle { cmd_txs, ack_rx };
(m, c, Some(h))
} else {
@@ -323,7 +423,10 @@ fn main() {
let tokenizer = Tokenizer::from_file(&opts.model_dir.join("tokenizer.json"));
let mut decoder = GraphedGptOssDecoder::new();
if let Some(h) = &tp_handle { h.send(TpCommand::Register(SLOT)); h.wait(); }
if let Some(h) = &tp_handle {
h.send(TpCommand::Register(SLOT));
h.wait();
}
cache.register_sequence(SLOT).expect("register chat slot");
let use_color = opts.color && io::stdout().is_terminal();
@@ -365,11 +468,8 @@ fn main() {
if is_moe {
// Harmony multi-turn: re-render the whole conversation (prior
// analysis dropped) and re-prefill into a freshly cleared slot.
let prompt = build_conversation_gpt_oss(
opts.system_prompt.as_deref(),
&moe_history,
input,
);
let prompt =
build_conversation_gpt_oss(opts.system_prompt.as_deref(), &moe_history, input);
let prompt_tokens = tokenizer.encode(&prompt);
if prompt_tokens.is_empty() {
continue;
@@ -386,8 +486,17 @@ fn main() {
print!("assistant> ");
io::stdout().flush().unwrap();
let (_finish, answer) = generate_with_paged_cache(
&model, &mut decoder, &mut cache, &tokenizer, &prompt_tokens, &opts.sampling,
max_new_tokens, use_color, &tp_handle, is_moe, opts.enable_thinking,
&model,
&mut decoder,
&mut cache,
&tokenizer,
&prompt_tokens,
&opts.sampling,
max_new_tokens,
use_color,
&tp_handle,
is_moe,
opts.enable_thinking,
);
moe_history.push((input.to_string(), answer));
println!();
@@ -436,10 +545,24 @@ fn main() {
);
match finish {
Finish::Stop { token_id } => {
append_after_stop(&model, &mut cache, &tokenizer, max_seq_len, token_id, &tp_handle);
append_after_stop(
&model,
&mut cache,
&tokenizer,
max_seq_len,
token_id,
&tp_handle,
);
}
Finish::Length => {
append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, "<|im_end|>\n", &tp_handle);
append_text_to_cache(
&model,
&mut cache,
&tokenizer,
max_seq_len,
"<|im_end|>\n",
&tp_handle,
);
}
}
println!();
@@ -448,9 +571,15 @@ fn main() {
/// Free and re-register the single chat KV slot (clears all cached context).
fn reset_slot(cache: &mut PagedKVCache, tp: &Option<TpHandle>) {
if let Some(h) = tp { h.send(TpCommand::Free(SLOT)); h.wait(); }
if let Some(h) = tp {
h.send(TpCommand::Free(SLOT));
h.wait();
}
cache.free_sequence(SLOT);
if let Some(h) = tp { h.send(TpCommand::Register(SLOT)); h.wait(); }
if let Some(h) = tp {
h.send(TpCommand::Register(SLOT));
h.wait();
}
cache.register_sequence(SLOT).expect("register chat slot");
}
@@ -588,7 +717,15 @@ fn new_paged_cache(config: &ModelConfig, max_seq_len: usize) -> PagedKVCache {
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
let total_blocks = (max_blocks_per_seq + 1).max(2);
// Single-slot interactive CLI: no swap pool (cpu_total_blocks = 0).
PagedKVCache::new(config, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, 0)
PagedKVCache::new(
config,
total_blocks,
0,
1,
max_blocks_per_seq,
DType::BF16,
0,
)
}
fn build_turn_prompt(
@@ -668,7 +805,10 @@ fn build_conversation_gpt_oss(
/// civil-calendar conversion (same algorithm the server uses for strftime_now).
fn today_ymd() -> String {
use std::time::{SystemTime, UNIX_EPOCH};
let secs = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
let secs = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
let z = (secs / 86400) as i64 + 719468;
let era = (if z >= 0 { z } else { z - 146096 }) / 146097;
let doe = z - era * 146097;
@@ -709,12 +849,32 @@ fn generate_with_paged_cache(
is_moe: bool,
enable_thinking: bool,
) -> (Finish, String) {
let harmony_end_id = if is_moe { tokenizer.special_token_id("<|end|>") } else { None };
let harmony_channel_id = if is_moe { tokenizer.special_token_id("<|channel|>") } else { None };
let harmony_message_id = if is_moe { tokenizer.special_token_id("<|message|>") } else { None };
let harmony_end_id = if is_moe {
tokenizer.special_token_id("<|end|>")
} else {
None
};
let harmony_channel_id = if is_moe {
tokenizer.special_token_id("<|channel|>")
} else {
None
};
let harmony_message_id = if is_moe {
tokenizer.special_token_id("<|message|>")
} else {
None
};
let harmony_special: Vec<u32> = if is_moe {
["<|channel|>", "<|start|>", "<|end|>", "<|message|>", "<|return|>"]
.iter().filter_map(|s| tokenizer.special_token_id(s)).collect()
[
"<|channel|>",
"<|start|>",
"<|end|>",
"<|message|>",
"<|return|>",
]
.iter()
.filter_map(|s| tokenizer.special_token_id(s))
.collect()
} else {
Vec::new()
};
@@ -722,18 +882,29 @@ fn generate_with_paged_cache(
// "analysis" channel is rendered as thinking (gray). After <|channel|>
// we read the channel name tokens until <|message|>.
#[derive(PartialEq, Clone, Copy)]
enum HarmonyState { Normal, ReadingChannel, InAnalysis, InFinal }
let mut hstate = if is_moe { HarmonyState::InFinal } else { HarmonyState::Normal };
enum HarmonyState {
Normal,
ReadingChannel,
InAnalysis,
InFinal,
}
let mut hstate = if is_moe {
HarmonyState::InFinal
} else {
HarmonyState::Normal
};
// Off by default. A repetition penalty over a harmony stream penalizes the
// control tokens (<|channel|>, <|message|>, <|start|>) that MUST repeat to
// open the final channel — so a non-1.0 default makes gpt-oss stop right
// after the analysis block, before emitting any answer. Opt in via the env
// var if you want it for plain (non-harmony) generation.
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY").ok()
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(1.0);
let rep_window: usize = std::env::var("XSERV_REP_WINDOW").ok()
let rep_window: usize = std::env::var("XSERV_REP_WINDOW")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(512);
let mut history: Vec<u32> = Vec::new();
@@ -747,9 +918,16 @@ fn generate_with_paged_cache(
}
};
if let Some(h) = tp { h.send(TpCommand::Prefill { tokens: prompt_tokens.to_vec(), slot: SLOT }); }
if let Some(h) = tp {
h.send(TpCommand::Prefill {
tokens: prompt_tokens.to_vec(),
slot: SLOT,
});
}
let logits = model.forward_prefill_paged(prompt_tokens, SLOT, cache);
if let Some(h) = tp { h.wait(); }
if let Some(h) = tp {
h.wait();
}
let mut next = pick(&logits, sampling, &history);
let mut decode_buffer = Vec::new();
let mut in_thinking = false;
@@ -762,9 +940,17 @@ fn generate_with_paged_cache(
for _ in 0..max_tokens {
let position = cache.seq_len(SLOT);
if let Some(h) = tp { h.send(TpCommand::Decode { tokens: vec![next], positions: vec![position], slots: vec![SLOT] }); }
if let Some(h) = tp {
h.send(TpCommand::Decode {
tokens: vec![next],
positions: vec![position],
slots: vec![SLOT],
});
}
let logits = chat_decode(model, decoder, &[next], &[position], &[SLOT], cache);
if let Some(h) = tp { h.wait(); }
if let Some(h) = tp {
h.wait();
}
if tokenizer.is_eos(next) {
print_stream_text(
&tokenizer.flush_decode_stream(&mut decode_buffer),
@@ -775,7 +961,10 @@ fn generate_with_paged_cache(
print_stream_text("\n</think>\n\n", true, use_color);
}
io::stdout().flush().unwrap();
return (Finish::Stop { token_id: next }, tokenizer.decode(&answer_ids));
return (
Finish::Stop { token_id: next },
tokenizer.decode(&answer_ids),
);
}
if harmony_end_id == Some(next) {
// <|end|> closes current segment; if in final channel, we're done
@@ -786,7 +975,10 @@ fn generate_with_paged_cache(
);
if hstate == HarmonyState::InFinal {
io::stdout().flush().unwrap();
return (Finish::Stop { token_id: next }, tokenizer.decode(&answer_ids));
return (
Finish::Stop { token_id: next },
tokenizer.decode(&answer_ids),
);
}
// Closing a thinking (analysis/commentary) channel: emit the </think>
// marker so it renders like Qwen3's thinking block.
@@ -842,7 +1034,13 @@ fn generate_with_paged_cache(
// Analysis channel = the model's reasoning. With --think, show it as a
// thinking block (gray if color); otherwise suppress it (answer only).
if show_thinking {
print_generated_token(tokenizer, next, &mut decode_buffer, &mut in_thinking, use_color);
print_generated_token(
tokenizer,
next,
&mut decode_buffer,
&mut in_thinking,
use_color,
);
io::stdout().flush().unwrap();
}
next = pick(&logits, sampling, &history);
@@ -904,9 +1102,16 @@ fn append_text_to_cache(
if tokens.is_empty() || cache.seq_len(SLOT) + tokens.len() > max_seq_len {
return;
}
if let Some(h) = tp { h.send(TpCommand::Prefill { tokens: tokens.clone(), slot: SLOT }); }
if let Some(h) = tp {
h.send(TpCommand::Prefill {
tokens: tokens.clone(),
slot: SLOT,
});
}
let _ = model.forward_prefill_paged(&tokens, SLOT, cache);
if let Some(h) = tp { h.wait(); }
if let Some(h) = tp {
h.wait();
}
}
fn print_generated_token(
@@ -952,4 +1157,3 @@ fn print_stream_text(text: &str, in_thinking: bool, use_color: bool) {
print!("{text}");
}
}

View File

@@ -1,6 +1,6 @@
use std::io::{self, Write};
use std::path::PathBuf;
use xserv_model::{loader, KVCache, ModelConfig, PagedKVCache, BLOCK_SIZE};
use xserv_model::{BLOCK_SIZE, KVCache, ModelConfig, PagedKVCache, loader};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
@@ -21,14 +21,21 @@ fn main() {
xserv_cuda::device::set_device(0).unwrap();
let info = xserv_cuda::device::device_info(0).unwrap();
eprintln!("GPU: {} ({} MB free)", info.name, info.free_memory / 1024 / 1024);
eprintln!(
"GPU: {} ({} MB free)",
info.name,
info.free_memory / 1024 / 1024
);
let config = ModelConfig::from_file(&model_dir.join("config.json"));
let model_type = config.model_type.as_deref().unwrap_or("unknown");
eprintln!(
"Model: {model_type}, layers={}, hidden={}, heads={}/{} kv, vocab={}",
config.num_layers(), config.hidden(), config.num_heads(),
config.num_kv_heads(), config.vocab_size
config.num_layers(),
config.hidden(),
config.num_heads(),
config.num_kv_heads(),
config.vocab_size
);
eprintln!("Loading weights...");
@@ -37,7 +44,11 @@ fn main() {
let is_qwen3 = model_type.contains("qwen");
let is_gpt_oss = model_type.contains("gpt_oss");
let dtype = if is_qwen3 || is_gpt_oss { DType::BF16 } else { DType::F32 };
let dtype = if is_qwen3 || is_gpt_oss {
DType::BF16
} else {
DType::F32
};
// Build model
enum Model {
@@ -60,10 +71,16 @@ fn main() {
print!("xserv> ");
io::stdout().flush().unwrap();
let mut input = String::new();
if io::stdin().read_line(&mut input).unwrap() == 0 { break; }
if io::stdin().read_line(&mut input).unwrap() == 0 {
break;
}
let input = input.trim();
if input.is_empty() { continue; }
if input == "quit" || input == "exit" { break; }
if input.is_empty() {
continue;
}
if input == "quit" || input == "exit" {
break;
}
let token_ids = tokenizer.encode(input);
@@ -73,12 +90,21 @@ fn main() {
let max_blocks_per_seq = (max_seq + BLOCK_SIZE - 1) / BLOCK_SIZE;
let total_blocks = max_blocks_per_seq + 64;
let mut paged_cache = PagedKVCache::new(
&config, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, 0,
&config,
total_blocks,
0,
4,
max_blocks_per_seq,
DType::BF16,
0,
);
let slot = 0;
paged_cache.register_sequence(slot).expect("register slot");
let model = match &model { Model::GptOss(m) => m, _ => unreachable!() };
let model = match &model {
Model::GptOss(m) => m,
_ => unreachable!(),
};
let logits = model.forward_prefill_paged(&token_ids, slot, &mut paged_cache);
let mut next = sample_greedy_last(&logits);
@@ -90,20 +116,28 @@ fn main() {
print!("{text}");
io::stdout().flush().unwrap();
if tokenizer.eos_token_id() == Some(next) { break; }
if tokenizer.eos_token_id() == Some(next) {
break;
}
let pos = paged_cache.seq_len(slot);
let logits = model.forward_decode_paged(
&[next], &[pos], &[slot], &mut paged_cache,
);
let logits = model.forward_decode_paged(&[next], &[pos], &[slot], &mut paged_cache);
next = sample_greedy_last(&logits);
}
println!();
paged_cache.free_sequence(slot);
} else {
let kv_heads = if is_qwen3 { config.num_kv_heads() } else { config.num_heads() };
let kv_heads = if is_qwen3 {
config.num_kv_heads()
} else {
config.num_heads()
};
let mut cache = KVCache::new(
config.num_layers(), kv_heads, config.head_dim(), dtype, Device::Cuda(0),
config.num_layers(),
kv_heads,
config.head_dim(),
dtype,
Device::Cuda(0),
);
let logits = match &model {
@@ -125,7 +159,9 @@ fn main() {
print!("{text}");
io::stdout().flush().unwrap();
if tokenizer.eos_token_id() == Some(next) { break; }
if tokenizer.eos_token_id() == Some(next) {
break;
}
let logits = match &model {
Model::GPT2(m) => m.forward_with_cache(&[next], &mut cache),
@@ -151,7 +187,9 @@ fn sample_greedy_last(logits: &xserv_tensor::Tensor) -> u32 {
let seq_len = logits.shape()[0];
let data = logits_cpu.as_slice::<bf16>();
let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
last.iter().enumerate()
last.iter()
.enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(i, _)| i as u32).unwrap()
.map(|(i, _)| i as u32)
.unwrap()
}

View File

@@ -88,23 +88,33 @@ impl ModelConfig {
}
pub fn hidden(&self) -> usize {
self.hidden_size.or(self.n_embd).expect("hidden_size or n_embd required")
self.hidden_size
.or(self.n_embd)
.expect("hidden_size or n_embd required")
}
pub fn num_heads(&self) -> usize {
self.num_attention_heads.or(self.n_head).expect("num_attention_heads or n_head required")
self.num_attention_heads
.or(self.n_head)
.expect("num_attention_heads or n_head required")
}
pub fn num_layers(&self) -> usize {
self.num_hidden_layers.or(self.n_layer).expect("num_hidden_layers or n_layer required")
self.num_hidden_layers
.or(self.n_layer)
.expect("num_hidden_layers or n_layer required")
}
pub fn max_seq_len(&self) -> usize {
self.max_position_embeddings.or(self.n_positions).unwrap_or(2048)
self.max_position_embeddings
.or(self.n_positions)
.unwrap_or(2048)
}
pub fn ffn_hidden(&self) -> usize {
self.intermediate_size.or(self.n_inner).unwrap_or(self.hidden() * 4)
self.intermediate_size
.or(self.n_inner)
.unwrap_or(self.hidden() * 4)
}
pub fn num_kv_heads(&self) -> usize {
@@ -112,7 +122,8 @@ impl ModelConfig {
}
pub fn head_dim(&self) -> usize {
self.explicit_head_dim.unwrap_or_else(|| self.hidden() / self.num_heads())
self.explicit_head_dim
.unwrap_or_else(|| self.hidden() / self.num_heads())
}
pub fn ln_eps(&self) -> f32 {

View File

@@ -18,19 +18,19 @@ use crate::kv_cache::GpuKVCache;
/// All buffers have stable GPU addresses for CUDA Graph replay.
struct DecodeBuffers {
// Hidden-size buffers: [1, hidden]
x: GpuBuffer, // running hidden state
normed: GpuBuffer, // rmsnorm output
attn_out: GpuBuffer, // attention output [1, num_heads, 1, head_dim]
attn_merged: GpuBuffer, // merge_heads output [1, hidden]
o_proj: GpuBuffer, // O projection output [1, hidden]
normed2: GpuBuffer, // post-attn norm output [1, hidden]
sum_out: GpuBuffer, // add_rmsnorm sum output [1, hidden]
down: GpuBuffer, // down projection output [1, hidden]
x: GpuBuffer, // running hidden state
normed: GpuBuffer, // rmsnorm output
attn_out: GpuBuffer, // attention output [1, num_heads, 1, head_dim]
attn_merged: GpuBuffer, // merge_heads output [1, hidden]
o_proj: GpuBuffer, // O projection output [1, hidden]
normed2: GpuBuffer, // post-attn norm output [1, hidden]
sum_out: GpuBuffer, // add_rmsnorm sum output [1, hidden]
down: GpuBuffer, // down projection output [1, hidden]
// QKV projection outputs
q_proj: GpuBuffer, // [1, num_heads * head_dim]
k_proj: GpuBuffer, // [1, num_kv_heads * head_dim]
v_proj: GpuBuffer, // [1, num_kv_heads * head_dim]
q_proj: GpuBuffer, // [1, num_heads * head_dim]
k_proj: GpuBuffer, // [1, num_kv_heads * head_dim]
v_proj: GpuBuffer, // [1, num_kv_heads * head_dim]
// Reshaped: [1, H, 1, D]
q_reshaped: GpuBuffer,
@@ -50,23 +50,23 @@ struct DecodeBuffers {
k_final: GpuBuffer,
// FFN intermediates
gate: GpuBuffer, // [1, intermediate]
up: GpuBuffer, // [1, intermediate]
silu_out: GpuBuffer, // [1, intermediate]
gate: GpuBuffer, // [1, intermediate]
up: GpuBuffer, // [1, intermediate]
silu_out: GpuBuffer, // [1, intermediate]
// GEMV fp32 accumulators (separate per output dimension)
fp32_hidden: GpuBuffer, // for hidden-sized GEMV outputs
fp32_q: GpuBuffer, // for Q projection
fp32_kv: GpuBuffer, // for K/V projection
fp32_intermediate: GpuBuffer,// for gate/up projections
fp32_vocab: GpuBuffer, // for lm_head
fp32_hidden: GpuBuffer, // for hidden-sized GEMV outputs
fp32_q: GpuBuffer, // for Q projection
fp32_kv: GpuBuffer, // for K/V projection
fp32_intermediate: GpuBuffer, // for gate/up projections
fp32_vocab: GpuBuffer, // for lm_head
// Token ID and position (GPU-resident, updated before replay)
token_id_gpu: GpuBuffer, // 4 bytes (u32)
position_gpu: GpuBuffer, // 4 bytes (u32)
token_id_gpu: GpuBuffer, // 4 bytes (u32)
position_gpu: GpuBuffer, // 4 bytes (u32)
// Final output
logits: GpuBuffer, // [1, vocab_size]
logits: GpuBuffer, // [1, vocab_size]
}
pub struct DecodeGraphState {
@@ -199,127 +199,296 @@ impl DecodeGraphState {
let cublas = cublas_handle();
// Set cuBLAS to use our stream
unsafe { dispatch::set_cublas_stream(cublas, s); }
unsafe {
dispatch::set_cublas_stream(cublas, s);
}
for (l, lw) in layers.iter().enumerate() {
// === Pre-attention graph ===
self.pre_attn_graphs[l].begin_capture(&self.stream).expect("begin pre-attn capture");
self.pre_attn_graphs[l]
.begin_capture(&self.stream)
.expect("begin pre-attn capture");
unsafe {
// RMSNorm
dispatch::rmsnorm_bf16(
self.buffers.x.as_ptr() as _, lw.input_norm, self.buffers.normed.as_mut_ptr() as _,
1, h, eps, s,
self.buffers.x.as_ptr() as _,
lw.input_norm,
self.buffers.normed.as_mut_ptr() as _,
1,
h,
eps,
s,
);
// Q projection (GEMV)
dispatch::gemv_bf16(
self.buffers.normed.as_ptr() as _, lw.q_proj_wt, self.buffers.q_proj.as_mut_ptr() as _,
self.buffers.normed.as_ptr() as _,
lw.q_proj_wt,
self.buffers.q_proj.as_mut_ptr() as _,
self.buffers.fp32_q.as_mut_ptr() as _,
h, nh * hd, s,
h,
nh * hd,
s,
);
// K projection (GEMV)
dispatch::gemv_bf16(
self.buffers.normed.as_ptr() as _, lw.k_proj_wt, self.buffers.k_proj.as_mut_ptr() as _,
self.buffers.normed.as_ptr() as _,
lw.k_proj_wt,
self.buffers.k_proj.as_mut_ptr() as _,
self.buffers.fp32_kv.as_mut_ptr() as _,
h, nkv * hd, s,
h,
nkv * hd,
s,
);
// V projection (GEMV)
dispatch::gemv_bf16(
self.buffers.normed.as_ptr() as _, lw.v_proj_wt, self.buffers.v_proj.as_mut_ptr() as _,
self.buffers.normed.as_ptr() as _,
lw.v_proj_wt,
self.buffers.v_proj.as_mut_ptr() as _,
self.buffers.fp32_kv.as_mut_ptr() as _,
h, nkv * hd, s,
h,
nkv * hd,
s,
);
// Reshape heads: [1, H*D] -> [1, H, 1, D]
dispatch::reshape_heads_bf16(self.buffers.q_proj.as_ptr() as _, self.buffers.q_reshaped.as_mut_ptr() as _, 1, nh, hd, s);
dispatch::reshape_heads_bf16(self.buffers.k_proj.as_ptr() as _, self.buffers.k_reshaped.as_mut_ptr() as _, 1, nkv, hd, s);
dispatch::reshape_heads_bf16(self.buffers.v_proj.as_ptr() as _, self.buffers.v_reshaped.as_mut_ptr() as _, 1, nkv, hd, s);
dispatch::reshape_heads_bf16(
self.buffers.q_proj.as_ptr() as _,
self.buffers.q_reshaped.as_mut_ptr() as _,
1,
nh,
hd,
s,
);
dispatch::reshape_heads_bf16(
self.buffers.k_proj.as_ptr() as _,
self.buffers.k_reshaped.as_mut_ptr() as _,
1,
nkv,
hd,
s,
);
dispatch::reshape_heads_bf16(
self.buffers.v_proj.as_ptr() as _,
self.buffers.v_reshaped.as_mut_ptr() as _,
1,
nkv,
hd,
s,
);
// QK norm (head-level rmsnorm: treat [1,H,1,D] as [H, D])
dispatch::rmsnorm_bf16(self.buffers.q_reshaped.as_ptr() as _, lw.q_norm, self.buffers.q_normed.as_mut_ptr() as _, nh, hd, eps, s);
dispatch::rmsnorm_bf16(self.buffers.k_reshaped.as_ptr() as _, lw.k_norm, self.buffers.k_normed.as_mut_ptr() as _, nkv, hd, eps, s);
dispatch::rmsnorm_bf16(
self.buffers.q_reshaped.as_ptr() as _,
lw.q_norm,
self.buffers.q_normed.as_mut_ptr() as _,
nh,
hd,
eps,
s,
);
dispatch::rmsnorm_bf16(
self.buffers.k_reshaped.as_ptr() as _,
lw.k_norm,
self.buffers.k_normed.as_mut_ptr() as _,
nkv,
hd,
eps,
s,
);
// Transpose for RoPE: [1,H,1,D] -> [1,H,D]
dispatch::transpose_hsd_to_shd_bf16(self.buffers.q_normed.as_ptr() as _, self.buffers.q_rope.as_mut_ptr() as _, 1, nh, hd, s);
dispatch::transpose_hsd_to_shd_bf16(self.buffers.k_normed.as_ptr() as _, self.buffers.k_rope.as_mut_ptr() as _, 1, nkv, hd, s);
dispatch::transpose_hsd_to_shd_bf16(
self.buffers.q_normed.as_ptr() as _,
self.buffers.q_rope.as_mut_ptr() as _,
1,
nh,
hd,
s,
);
dispatch::transpose_hsd_to_shd_bf16(
self.buffers.k_normed.as_ptr() as _,
self.buffers.k_rope.as_mut_ptr() as _,
1,
nkv,
hd,
s,
);
// RoPE (in-place, reads position_gpu)
dispatch::rope_bf16(self.buffers.q_rope.as_mut_ptr() as _, rope_cos, rope_sin, self.buffers.position_gpu.as_ptr() as _, 1, nh, hd, s);
dispatch::rope_bf16(self.buffers.k_rope.as_mut_ptr() as _, rope_cos, rope_sin, self.buffers.position_gpu.as_ptr() as _, 1, nkv, hd, s);
dispatch::rope_bf16(
self.buffers.q_rope.as_mut_ptr() as _,
rope_cos,
rope_sin,
self.buffers.position_gpu.as_ptr() as _,
1,
nh,
hd,
s,
);
dispatch::rope_bf16(
self.buffers.k_rope.as_mut_ptr() as _,
rope_cos,
rope_sin,
self.buffers.position_gpu.as_ptr() as _,
1,
nkv,
hd,
s,
);
// Transpose back: [1,H,D] -> [1,H,1,D]
dispatch::transpose_shd_to_hsd_bf16(self.buffers.q_rope.as_ptr() as _, self.buffers.q_final.as_mut_ptr() as _, 1, nh, hd, s);
dispatch::transpose_shd_to_hsd_bf16(self.buffers.k_rope.as_ptr() as _, self.buffers.k_final.as_mut_ptr() as _, 1, nkv, hd, s);
dispatch::transpose_shd_to_hsd_bf16(
self.buffers.q_rope.as_ptr() as _,
self.buffers.q_final.as_mut_ptr() as _,
1,
nh,
hd,
s,
);
dispatch::transpose_shd_to_hsd_bf16(
self.buffers.k_rope.as_ptr() as _,
self.buffers.k_final.as_mut_ptr() as _,
1,
nkv,
hd,
s,
);
}
self.pre_attn_graphs[l].end_capture(&self.stream).expect("end pre-attn capture");
self.pre_attn_graphs[l]
.end_capture(&self.stream)
.expect("end pre-attn capture");
// === Post-attention graph ===
self.post_attn_graphs[l].begin_capture(&self.stream).expect("begin post-attn capture");
self.post_attn_graphs[l]
.begin_capture(&self.stream)
.expect("begin post-attn capture");
unsafe {
// Merge heads: [1,H,1,D] -> [1, hidden]
// attn_out is written by ungraphed attention
dispatch::merge_heads_bf16(self.buffers.attn_out.as_ptr() as _, self.buffers.attn_merged.as_mut_ptr() as _, 1, nh, hd, s);
dispatch::merge_heads_bf16(
self.buffers.attn_out.as_ptr() as _,
self.buffers.attn_merged.as_mut_ptr() as _,
1,
nh,
hd,
s,
);
// O projection
dispatch::gemv_bf16(
self.buffers.attn_merged.as_ptr() as _, lw.o_proj_wt, self.buffers.o_proj.as_mut_ptr() as _,
self.buffers.attn_merged.as_ptr() as _,
lw.o_proj_wt,
self.buffers.o_proj.as_mut_ptr() as _,
self.buffers.fp32_hidden.as_mut_ptr() as _,
nh * hd, h, s,
nh * hd,
h,
s,
);
// Fused Add+RMSNorm: normed2 = rmsnorm(o_proj + x), sum_out = o_proj + x
dispatch::add_rmsnorm_bf16(
self.buffers.o_proj.as_ptr() as _, self.buffers.x.as_ptr() as _, lw.post_norm,
self.buffers.normed2.as_mut_ptr() as _, self.buffers.sum_out.as_mut_ptr() as _,
1, h, eps, s,
self.buffers.o_proj.as_ptr() as _,
self.buffers.x.as_ptr() as _,
lw.post_norm,
self.buffers.normed2.as_mut_ptr() as _,
self.buffers.sum_out.as_mut_ptr() as _,
1,
h,
eps,
s,
);
// Gate projection
dispatch::gemv_bf16(
self.buffers.normed2.as_ptr() as _, lw.gate_proj_wt, self.buffers.gate.as_mut_ptr() as _,
self.buffers.normed2.as_ptr() as _,
lw.gate_proj_wt,
self.buffers.gate.as_mut_ptr() as _,
self.buffers.fp32_intermediate.as_mut_ptr() as _,
h, inter, s,
h,
inter,
s,
);
// Up projection
dispatch::gemv_bf16(
self.buffers.normed2.as_ptr() as _, lw.up_proj_wt, self.buffers.up.as_mut_ptr() as _,
self.buffers.normed2.as_ptr() as _,
lw.up_proj_wt,
self.buffers.up.as_mut_ptr() as _,
self.buffers.fp32_intermediate.as_mut_ptr() as _,
h, inter, s,
h,
inter,
s,
);
// Fused SiLU x Mul
dispatch::silu_mul_bf16(self.buffers.gate.as_ptr() as _, self.buffers.up.as_ptr() as _, self.buffers.silu_out.as_mut_ptr() as _, inter, s);
dispatch::silu_mul_bf16(
self.buffers.gate.as_ptr() as _,
self.buffers.up.as_ptr() as _,
self.buffers.silu_out.as_mut_ptr() as _,
inter,
s,
);
// Down projection
dispatch::gemv_bf16(
self.buffers.silu_out.as_ptr() as _, lw.down_proj_wt, self.buffers.down.as_mut_ptr() as _,
self.buffers.silu_out.as_ptr() as _,
lw.down_proj_wt,
self.buffers.down.as_mut_ptr() as _,
self.buffers.fp32_hidden.as_mut_ptr() as _,
inter, h, s,
inter,
h,
s,
);
// x = sum_out + down (residual connection for next layer)
dispatch::add_bf16(self.buffers.sum_out.as_ptr() as _, self.buffers.down.as_ptr() as _, self.buffers.x.as_mut_ptr() as _, h, s);
dispatch::add_bf16(
self.buffers.sum_out.as_ptr() as _,
self.buffers.down.as_ptr() as _,
self.buffers.x.as_mut_ptr() as _,
h,
s,
);
}
self.post_attn_graphs[l].end_capture(&self.stream).expect("end post-attn capture");
self.post_attn_graphs[l]
.end_capture(&self.stream)
.expect("end post-attn capture");
}
// === Final graph: norm + lm_head ===
self.final_graph.begin_capture(&self.stream).expect("begin final capture");
self.final_graph
.begin_capture(&self.stream)
.expect("begin final capture");
unsafe {
dispatch::rmsnorm_bf16(self.buffers.x.as_ptr() as _, norm_weight, self.buffers.normed.as_mut_ptr() as _, 1, h, eps, s);
dispatch::rmsnorm_bf16(
self.buffers.x.as_ptr() as _,
norm_weight,
self.buffers.normed.as_mut_ptr() as _,
1,
h,
eps,
s,
);
dispatch::gemv_bf16(
self.buffers.normed.as_ptr() as _, lm_head_wt, self.buffers.logits.as_mut_ptr() as _,
self.buffers.normed.as_ptr() as _,
lm_head_wt,
self.buffers.logits.as_mut_ptr() as _,
self.buffers.fp32_vocab.as_mut_ptr() as _,
h, vocab, s,
h,
vocab,
s,
);
}
self.final_graph.end_capture(&self.stream).expect("end final capture");
self.final_graph
.end_capture(&self.stream)
.expect("end final capture");
// Reset cuBLAS back to null stream
unsafe { dispatch::set_cublas_stream(cublas, std::ptr::null_mut()); }
unsafe {
dispatch::set_cublas_stream(cublas, std::ptr::null_mut());
}
self.captured = true;
}
@@ -343,8 +512,14 @@ impl DecodeGraphState {
let es = 2usize; // BF16
// Upload token ID and position to fixed GPU buffers
self.buffers.token_id_gpu.copy_from_host(&token_id.to_le_bytes()).unwrap();
self.buffers.position_gpu.copy_from_host(&position.to_le_bytes()).unwrap();
self.buffers
.token_id_gpu
.copy_from_host(&token_id.to_le_bytes())
.unwrap();
self.buffers
.position_gpu
.copy_from_host(&position.to_le_bytes())
.unwrap();
// Embedding (outside graph since token_id changes each step)
unsafe {
@@ -352,13 +527,18 @@ impl DecodeGraphState {
embed_table,
self.buffers.token_id_gpu.as_ptr() as _,
self.buffers.x.as_mut_ptr() as _,
1, hidden_size, vocab_size, s,
1,
hidden_size,
vocab_size,
s,
);
}
for l in 0..self.num_layers {
// Pre-attention graph (norm + QKV + reshape + QK-norm + RoPE)
self.pre_attn_graphs[l].launch(&self.stream).expect("launch pre-attn graph");
self.pre_attn_graphs[l]
.launch(&self.stream)
.expect("launch pre-attn graph");
// Ungraphed: KV cache append
// k_final shape: [1, num_kv_heads, 1, head_dim] (after RoPE pipeline)
@@ -402,9 +582,13 @@ impl DecodeGraphState {
k_full.data_ptr() as _,
v_full.data_ptr() as _,
self.buffers.attn_out.as_mut_ptr() as _,
1, nh as i32, nkv as i32,
kv_len, hd as i32,
scale, s,
1,
nh as i32,
nkv as i32,
kv_len,
hd as i32,
scale,
s,
);
}
@@ -412,11 +596,15 @@ impl DecodeGraphState {
self.stream.synchronize().expect("sync before post-attn");
// Post-attention graph (merge + O-proj + add_rmsnorm + FFN + residual)
self.post_attn_graphs[l].launch(&self.stream).expect("launch post-attn graph");
self.post_attn_graphs[l]
.launch(&self.stream)
.expect("launch post-attn graph");
}
// Final graph (norm + lm_head)
self.final_graph.launch(&self.stream).expect("launch final graph");
self.final_graph
.launch(&self.stream)
.expect("launch final graph");
// Sync to ensure logits are ready
self.stream.synchronize().expect("sync after decode");

View File

@@ -31,7 +31,7 @@ struct GPT2Block {
pub struct KVCache {
// Per layer, per head: raw bytes (works for both f32 and bf16)
k: Vec<Vec<Vec<u8>>>, // [num_layers][num_heads][seq_len * head_dim * elem_size]
k: Vec<Vec<Vec<u8>>>, // [num_layers][num_heads][seq_len * head_dim * elem_size]
v: Vec<Vec<Vec<u8>>>,
len: usize,
num_heads: usize,
@@ -42,7 +42,13 @@ pub struct KVCache {
}
impl KVCache {
pub fn new(num_layers: usize, num_heads: usize, head_dim: usize, dtype: DType, device: Device) -> Self {
pub fn new(
num_layers: usize,
num_heads: usize,
head_dim: usize,
dtype: DType,
device: Device,
) -> Self {
Self {
k: (0..num_layers).map(|_| vec![vec![]; num_heads]).collect(),
v: (0..num_layers).map(|_| vec![vec![]; num_heads]).collect(),
@@ -55,10 +61,18 @@ impl KVCache {
}
}
pub fn seq_len(&self) -> usize { self.len }
pub fn seq_len(&self) -> usize {
self.len
}
/// Append from a CPU tensor with shape [1, H, new_tokens, D].
pub fn append_kv_tensor(&mut self, layer: usize, k_cpu: &Tensor, v_cpu: &Tensor, new_tokens: usize) {
pub fn append_kv_tensor(
&mut self,
layer: usize,
k_cpu: &Tensor,
v_cpu: &Tensor,
new_tokens: usize,
) {
let hd = self.head_dim;
let es = self.elem_size;
let k_bytes = k_cpu.storage().as_cpu_bytes();
@@ -118,7 +132,8 @@ impl GPT2 {
pub fn from_weights(config: ModelConfig, mut w: HashMap<String, Tensor>) -> Self {
crate::init_kernels();
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
w.remove(name)
.unwrap_or_else(|| panic!("missing weight: {name}"))
};
let wte = take(&mut w, "wte.weight");
@@ -147,7 +162,15 @@ impl GPT2 {
});
}
Self { config, wte, wpe, layers, ln_f_g, ln_f_b, lm_head }
Self {
config,
wte,
wpe,
layers,
ln_f_g,
ln_f_b,
lm_head,
}
}
/// Full forward pass without KV cache (for testing / correctness comparison).
@@ -179,14 +202,22 @@ impl GPT2 {
let head_dim = self.config.head_dim();
let tok_emb = embedding(&self.wte, token_ids);
let pos_ids: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
let pos_ids: Vec<u32> = (pos_offset..pos_offset + new_tokens)
.map(|p| p as u32)
.collect();
let pos_emb = embedding(&self.wpe, &pos_ids);
let mut x = add_tensors(&tok_emb, &pos_emb);
for (layer_idx, layer) in self.layers.iter().enumerate() {
x = self.transformer_block(
layer, &x, Some((cache, layer_idx)),
pos_offset, new_tokens, num_heads, head_dim, hidden,
layer,
&x,
Some((cache, layer_idx)),
pos_offset,
new_tokens,
num_heads,
head_dim,
hidden,
);
}
@@ -238,7 +269,11 @@ impl GPT2 {
fn linear(x: &Tensor, weight: &Tensor, bias: Option<&Tensor>) -> Tensor {
let out = matmul_2d(x, weight);
if let Some(b) = bias { add_bias(&out, b) } else { out }
if let Some(b) = bias {
add_bias(&out, b)
} else {
out
}
}
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
@@ -277,7 +312,12 @@ fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor {
}
}
fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) -> (Tensor, Tensor, Tensor) {
fn split_qkv(
qkv: &Tensor,
num_heads: usize,
head_dim: usize,
seq_len: usize,
) -> (Tensor, Tensor, Tensor) {
let hidden = num_heads * head_dim;
let qkv_cpu = qkv.to_device(Device::Cpu);
let device = qkv.device();
@@ -294,14 +334,21 @@ fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) ->
for h in 0..num_heads {
let src_off = h * head_dim;
let dst_off = (h * seq_len + s) * head_dim;
q_data[dst_off..dst_off + head_dim].copy_from_slice(&row[src_off..src_off + head_dim]);
k_data[dst_off..dst_off + head_dim].copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
v_data[dst_off..dst_off + head_dim].copy_from_slice(&row[2 * hidden + src_off..2 * hidden + src_off + head_dim]);
q_data[dst_off..dst_off + head_dim]
.copy_from_slice(&row[src_off..src_off + head_dim]);
k_data[dst_off..dst_off + head_dim]
.copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
v_data[dst_off..dst_off + head_dim].copy_from_slice(
&row[2 * hidden + src_off..2 * hidden + src_off + head_dim],
);
}
}
let q = Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let k = Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let v = Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let q =
Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let k =
Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let v =
Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
(q, k, v)
}
DType::BF16 => {
@@ -314,14 +361,21 @@ fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) ->
for h in 0..num_heads {
let src_off = h * head_dim;
let dst_off = (h * seq_len + s) * head_dim;
q_data[dst_off..dst_off + head_dim].copy_from_slice(&row[src_off..src_off + head_dim]);
k_data[dst_off..dst_off + head_dim].copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
v_data[dst_off..dst_off + head_dim].copy_from_slice(&row[2 * hidden + src_off..2 * hidden + src_off + head_dim]);
q_data[dst_off..dst_off + head_dim]
.copy_from_slice(&row[src_off..src_off + head_dim]);
k_data[dst_off..dst_off + head_dim]
.copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
v_data[dst_off..dst_off + head_dim].copy_from_slice(
&row[2 * hidden + src_off..2 * hidden + src_off + head_dim],
);
}
}
let q = Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let k = Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let v = Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let q =
Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let k =
Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
let v =
Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
(q, k, v)
}
_ => panic!("unsupported dtype {:?} in split_qkv", dtype),
@@ -343,7 +397,8 @@ fn merge_heads(x: &Tensor, seq_len: usize, hidden: usize) -> Tensor {
for h in 0..num_heads {
let src_off = (h * seq_len + s) * head_dim;
let dst_off = s * hidden + h * head_dim;
out[dst_off..dst_off + head_dim].copy_from_slice(&src[src_off..src_off + head_dim]);
out[dst_off..dst_off + head_dim]
.copy_from_slice(&src[src_off..src_off + head_dim]);
}
}
Tensor::from_slice(&out, &[seq_len, hidden]).to_device(device)
@@ -355,7 +410,8 @@ fn merge_heads(x: &Tensor, seq_len: usize, hidden: usize) -> Tensor {
for h in 0..num_heads {
let src_off = (h * seq_len + s) * head_dim;
let dst_off = s * hidden + h * head_dim;
out[dst_off..dst_off + head_dim].copy_from_slice(&src[src_off..src_off + head_dim]);
out[dst_off..dst_off + head_dim]
.copy_from_slice(&src[src_off..src_off + head_dim]);
}
}
Tensor::from_slice(&out, &[seq_len, hidden]).to_device(device)
@@ -372,7 +428,8 @@ pub fn sample_greedy(logits: &Tensor) -> u32 {
let vocab_size = logits.shape()[1];
let seq_len = logits.shape()[0];
let last_row = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
last_row.iter()
last_row
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.map(|(idx, _)| idx as u32)

View File

@@ -1,6 +1,6 @@
use half::bf16;
use std::collections::HashMap;
use std::ffi::c_void;
use half::bf16;
use xserv_kernels::*;
use xserv_tensor::{Device, Tensor};
@@ -49,10 +49,10 @@ struct GptOssBlock {
expert_down_bias: Tensor, // [local_experts, hidden]
// FP8 quantized expert weights (Some when running FP8 W8A8)
// Transposed layout [E, N, K] for cuBLASLt FP8 (Blackwell requires transA=T)
expert_gate_up_fp8: Option<Tensor>, // [local_experts, 2*inter, hidden] FP8E4M3
expert_gate_up_scale: Option<Tensor>,// [local_experts] F32
expert_down_fp8: Option<Tensor>, // [local_experts, hidden, inter] FP8E4M3
expert_down_scale: Option<Tensor>, // [local_experts] F32
expert_gate_up_fp8: Option<Tensor>, // [local_experts, 2*inter, hidden] FP8E4M3
expert_gate_up_scale: Option<Tensor>, // [local_experts] F32
expert_down_fp8: Option<Tensor>, // [local_experts, hidden, inter] FP8E4M3
expert_down_scale: Option<Tensor>, // [local_experts] F32
// MXFP4 W4A16 expert weights (Some when running 4-bit weight-only).
// (packed [E, N, K/2] u8, scales [E, N, K/32] u8) in [E, N, K] layout.
expert_gate_up_mxfp4: Option<(Tensor, Tensor)>,
@@ -79,16 +79,23 @@ impl GptOss {
crate::init_kernels();
let dev = Device::Cuda(device);
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
w.remove(name)
.unwrap_or_else(|| panic!("missing weight: {name}"))
};
let repl = |t: Tensor| -> Tensor { t.to_device(dev) };
// column-parallel: shard rows of [out, in], transpose → [in, out/world]
let col = |t: Tensor| -> Tensor {
shard_rows(&t, rank, world).to_device(dev).transpose(0, 1).contiguous()
shard_rows(&t, rank, world)
.to_device(dev)
.transpose(0, 1)
.contiguous()
};
// row-parallel: shard cols of [out, in], transpose → [in/world, out]
let row = |t: Tensor| -> Tensor {
shard_cols(&t, rank, world).to_device(dev).transpose(0, 1).contiguous()
shard_cols(&t, rank, world)
.to_device(dev)
.transpose(0, 1)
.contiguous()
};
// Bias sharding helpers
let col_bias = |t: Tensor| -> Tensor { shard_1d(&t, rank, world).to_device(dev) };
@@ -97,7 +104,9 @@ impl GptOss {
let embed_tokens = repl(take(&mut w, "model.embed_tokens.weight"));
let norm = repl(take(&mut w, "model.norm.weight"));
let norm_bias = w.remove("model.norm.bias").map(|t| repl(t));
let lm_head_t = repl(take(&mut w, "lm_head.weight")).transpose(0, 1).contiguous();
let lm_head_t = repl(take(&mut w, "lm_head.weight"))
.transpose(0, 1)
.contiguous();
let head_dim = config.head_dim();
let rope_theta = config.rope_theta.unwrap_or(150000.0);
@@ -176,15 +185,30 @@ impl GptOss {
// MXFP4 stores 4-bit weights in an FP8E4M3 byte container (same dtype
// as FP8), so distinguish by the scale rank: FP8 scale is 1-D [E],
// MXFP4 scale is 3-D [E, N, K/32].
let is_mxfp4 = gate_up_scale.as_ref().map(|s| s.ndim() == 3).unwrap_or(false);
let is_mxfp4 = gate_up_scale
.as_ref()
.map(|s| s.ndim() == 3)
.unwrap_or(false);
let is_fp8 = !is_mxfp4 && gate_up_3d.dtype() == xserv_tensor::DType::FP8E4M3;
let mut expert_gate_up_mxfp4: Option<(Tensor, Tensor)> = None;
let mut expert_down_mxfp4: Option<(Tensor, Tensor)> = None;
let inter2 = if is_mxfp4 { gate_up_3d.shape()[1] } else { gate_up_3d.shape()[2] }; // 2*inter (N)
let hidden = if is_mxfp4 { gate_up_3d.shape()[2] * 2 } else { gate_up_3d.shape()[1] };
let inter = if is_mxfp4 { down_3d.shape()[2] * 2 } else { down_3d.shape()[1] };
let inter2 = if is_mxfp4 {
gate_up_3d.shape()[1]
} else {
gate_up_3d.shape()[2]
}; // 2*inter (N)
let hidden = if is_mxfp4 {
gate_up_3d.shape()[2] * 2
} else {
gate_up_3d.shape()[1]
};
let inter = if is_mxfp4 {
down_3d.shape()[2] * 2
} else {
down_3d.shape()[1]
};
// Slice the rank's range of experts as contiguous 3D tensors on GPU
let expert_gate_up_wt;
@@ -199,10 +223,38 @@ impl GptOss {
// + scales [E, N, K/32]. Slice this rank's experts (raw bytes).
let gu_s = gate_up_scale.expect("MXFP4 model missing gate_up_proj_scale");
let d_s = down_scale.expect("MXFP4 model missing down_proj_scale");
let gu_packed = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, inter2, hidden / 2).to_device(dev);
let gu_scl = slice_expert_range_3d_raw(&gu_s, expert_start, local_experts, inter2, hidden / 32).to_device(dev);
let dn_packed = slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, hidden, inter / 2).to_device(dev);
let dn_scl = slice_expert_range_3d_raw(&d_s, expert_start, local_experts, hidden, inter / 32).to_device(dev);
let gu_packed = slice_expert_range_3d_raw(
&gate_up_3d,
expert_start,
local_experts,
inter2,
hidden / 2,
)
.to_device(dev);
let gu_scl = slice_expert_range_3d_raw(
&gu_s,
expert_start,
local_experts,
inter2,
hidden / 32,
)
.to_device(dev);
let dn_packed = slice_expert_range_3d_raw(
&down_3d,
expert_start,
local_experts,
hidden,
inter / 2,
)
.to_device(dev);
let dn_scl = slice_expert_range_3d_raw(
&d_s,
expert_start,
local_experts,
hidden,
inter / 32,
)
.to_device(dev);
expert_gate_up_mxfp4 = Some((gu_packed, gu_scl));
expert_down_mxfp4 = Some((dn_packed, dn_scl));
expert_gate_up_fp8 = None;
@@ -214,36 +266,65 @@ impl GptOss {
} else if is_fp8 {
// FP8 W8A8 path: load and TRANSPOSE weights for cuBLASLt (requires transA=T on Blackwell).
// Original: [E, K, N] → Transposed: [E, N, K]
let gu_sliced = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, hidden, inter2);
let dn_sliced = slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, inter, hidden);
expert_gate_up_fp8 = Some(transpose_3d_inner_raw(&gu_sliced, local_experts, hidden, inter2).to_device(dev));
expert_down_fp8 = Some(transpose_3d_inner_raw(&dn_sliced, local_experts, inter, hidden).to_device(dev));
let gu_sliced = slice_expert_range_3d_raw(
&gate_up_3d,
expert_start,
local_experts,
hidden,
inter2,
);
let dn_sliced =
slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, inter, hidden);
expert_gate_up_fp8 = Some(
transpose_3d_inner_raw(&gu_sliced, local_experts, hidden, inter2)
.to_device(dev),
);
expert_down_fp8 = Some(
transpose_3d_inner_raw(&dn_sliced, local_experts, inter, hidden).to_device(dev),
);
// Scales: [num_experts] F32 → slice to [local_experts]
let gu_s = gate_up_scale.expect("FP8 model missing gate_up_proj_scale");
let d_s = down_scale.expect("FP8 model missing down_proj_scale");
expert_gate_up_scale_gpu = Some(slice_scale_range(&gu_s, expert_start, local_experts).to_device(dev));
expert_down_scale_gpu = Some(slice_scale_range(&d_s, expert_start, local_experts).to_device(dev));
expert_gate_up_scale_gpu =
Some(slice_scale_range(&gu_s, expert_start, local_experts).to_device(dev));
expert_down_scale_gpu =
Some(slice_scale_range(&d_s, expert_start, local_experts).to_device(dev));
// Dummy BF16 tensors (never read in FP8 path)
expert_gate_up_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev);
expert_down_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev);
} else {
// BF16 path: existing behavior
expert_gate_up_wt = slice_expert_range_3d(&gate_up_3d, expert_start, local_experts, hidden, inter2).to_device(dev);
expert_down_wt = slice_expert_range_3d(&down_3d, expert_start, local_experts, inter, hidden).to_device(dev);
expert_gate_up_wt =
slice_expert_range_3d(&gate_up_3d, expert_start, local_experts, hidden, inter2)
.to_device(dev);
expert_down_wt =
slice_expert_range_3d(&down_3d, expert_start, local_experts, inter, hidden)
.to_device(dev);
expert_gate_up_fp8 = None;
expert_gate_up_scale_gpu = None;
expert_down_fp8 = None;
expert_down_scale_gpu = None;
}
let expert_gate_up_bias = slice_expert_range_2d(&gate_up_bias_2d, expert_start, local_experts, inter2).to_device(dev);
let expert_down_bias = slice_expert_range_2d(&down_bias_2d, expert_start, local_experts, hidden).to_device(dev);
let expert_gate_up_bias =
slice_expert_range_2d(&gate_up_bias_2d, expert_start, local_experts, inter2)
.to_device(dev);
let expert_down_bias =
slice_expert_range_2d(&down_bias_2d, expert_start, local_experts, hidden)
.to_device(dev);
xserv_cuda::allocator::cached_trim();
let input_norm = repl(take(&mut w, &format!("{p}.input_layernorm.weight")));
let input_norm_bias = w.remove(&format!("{p}.input_layernorm.bias")).map(|t| repl(t));
let post_norm = repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight")));
let post_norm_bias = w.remove(&format!("{p}.post_attention_layernorm.bias")).map(|t| repl(t));
let input_norm_bias = w
.remove(&format!("{p}.input_layernorm.bias"))
.map(|t| repl(t));
let post_norm = repl(take(
&mut w,
&format!("{p}.post_attention_layernorm.weight"),
));
let post_norm_bias = w
.remove(&format!("{p}.post_attention_layernorm.bias"))
.map(|t| repl(t));
layers.push(GptOssBlock {
input_norm,
@@ -283,17 +364,27 @@ impl GptOss {
let local_num_kv_heads = config.num_kv_heads() / world;
let has_norm_bias = norm_bias.is_some();
let is_fp8 = layers.first().map(|l| l.expert_gate_up_fp8.is_some()).unwrap_or(false);
let is_mxfp4 = layers.first().map(|l| l.expert_gate_up_mxfp4.is_some()).unwrap_or(false);
let is_fp8 = layers
.first()
.map(|l| l.expert_gate_up_fp8.is_some())
.unwrap_or(false);
let is_mxfp4 = layers
.first()
.map(|l| l.expert_gate_up_mxfp4.is_some())
.unwrap_or(false);
if rank == 0 {
if has_norm_bias {
eprintln!("gpt-oss: detected LayerNorm bias — using LayerNorm instead of RMSNorm");
}
if is_fp8 {
eprintln!("gpt-oss: FP8 E4M3 quantized expert weights detected (W8A8 cuBLASLt mode)");
eprintln!(
"gpt-oss: FP8 E4M3 quantized expert weights detected (W8A8 cuBLASLt mode)"
);
}
if is_mxfp4 {
eprintln!("gpt-oss: MXFP4 quantized expert weights detected (W4A16 fused-GEMV mode)");
eprintln!(
"gpt-oss: MXFP4 quantized expert weights detected (W4A16 fused-GEMV mode)"
);
}
}
@@ -341,7 +432,13 @@ impl GptOss {
}
#[inline]
fn add_norm(x: &Tensor, residual: &Tensor, weight: &Tensor, bias: &Option<Tensor>, eps: f32) -> (Tensor, Tensor) {
fn add_norm(
x: &Tensor,
residual: &Tensor,
weight: &Tensor,
bias: &Option<Tensor>,
eps: f32,
) -> (Tensor, Tensor) {
match bias {
Some(b) => {
let sum = xserv_kernels::add(x, residual);
@@ -439,7 +536,6 @@ impl GptOss {
let k_all = add_bias(&matmul_2d(&normed, &layer.k_proj_wt), &layer.k_proj_bias);
let v_all = add_bias(&matmul_2d(&normed, &layer.v_proj_wt), &layer.v_proj_bias);
// Reshape for RoPE: [B, H*D] → [B, H, D]
let q_3d = q_all.reshape(&[batch, num_heads, head_dim]);
let k_3d = k_all.reshape(&[batch, num_kv_heads, head_dim]);
@@ -460,9 +556,17 @@ impl GptOss {
let sinks_ptr = layer.sinks.data_ptr() as *const c_void;
let attn_out = paged_decode_attention_sinks(
&q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr,
&q_4d,
k_pool_ptr,
v_pool_ptr,
bt_ptr,
cl_ptr,
sinks_ptr,
batch, num_heads, num_kv_heads, head_dim, max_blocks,
batch,
num_heads,
num_kv_heads,
head_dim,
max_blocks,
layer.window_size,
);
@@ -471,9 +575,14 @@ impl GptOss {
self.all_reduce(&attn_proj);
let attn_proj = add_bias(&attn_proj, &layer.o_proj_bias);
// Residual + post-norm
let (normed, x_new) = Self::add_norm(&attn_proj, &residual, &layer.post_norm, &layer.post_norm_bias, eps);
let (normed, x_new) = Self::add_norm(
&attn_proj,
&residual,
&layer.post_norm,
&layer.post_norm_bias,
eps,
);
let residual = x_new;
let normed = normed.contiguous();
@@ -505,7 +614,9 @@ impl GptOss {
paged_cache.advance_seq_len(slot, new_tokens);
let mut x = embedding(&self.embed_tokens, token_ids);
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
.map(|p| p as u32)
.collect();
for (layer_idx, layer) in self.layers.iter().enumerate() {
let residual = x.clone();
@@ -532,14 +643,21 @@ impl GptOss {
let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
// Flash attention with gpt-oss sinks + (per-layer) sliding window.
let attn_out = flash_attention_sinks(&q, &k_full, &v_full, &layer.sinks, layer.window_size);
let attn_out =
flash_attention_sinks(&q, &k_full, &v_full, &layer.sinks, layer.window_size);
let attn_merged = merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
self.all_reduce(&attn_proj);
let attn_proj = add_bias(&attn_proj, &layer.o_proj_bias);
let (normed, x_new) = Self::add_norm(&attn_proj, &residual, &layer.post_norm, &layer.post_norm_bias, eps);
let (normed, x_new) = Self::add_norm(
&attn_proj,
&residual,
&layer.post_norm,
&layer.post_norm_bias,
eps,
);
let residual = x_new;
// MoE MLP
@@ -566,15 +684,11 @@ impl GptOss {
let expert_start = rank * local_experts;
// 1. Router: [tokens, hidden] @ [hidden, num_experts] + bias → [tokens, num_experts]
let router_logits = add_bias(
&matmul_2d(x, &layer.router_wt),
&layer.router_bias,
);
let router_logits = add_bias(&matmul_2d(x, &layer.router_wt), &layer.router_bias);
// 2. GPU top-k + softmax
let (topk_ids, topk_weights) = xserv_kernels::moe::moe_topk_softmax(
&router_logits, num_experts, top_k,
);
let (topk_ids, topk_weights) =
xserv_kernels::moe::moe_topk_softmax(&router_logits, num_experts, top_k);
// Sparse decode path: compute ONLY the routed experts. The dense path
// below reads every local expert's weights per forward; the sparse
@@ -588,15 +702,31 @@ impl GptOss {
let n = packed.shape()[1];
let k = packed.shape()[2] * 2;
xserv_kernels::moe::moe_sparse_gemv_mxfp4(
x, packed, scales, &layer.expert_gate_up_bias, &topk_ids,
num_tokens, top_k, n, k, expert_start, local_experts, false,
x,
packed,
scales,
&layer.expert_gate_up_bias,
&topk_ids,
num_tokens,
top_k,
n,
k,
expert_start,
local_experts,
false,
)
} else {
xserv_kernels::moe::moe_sparse_gemv_fp8(
x, layer.expert_gate_up_fp8.as_ref().unwrap(),
x,
layer.expert_gate_up_fp8.as_ref().unwrap(),
layer.expert_gate_up_scale.as_ref().unwrap(),
&layer.expert_gate_up_bias, &topk_ids,
num_tokens, top_k, expert_start, local_experts, false,
&layer.expert_gate_up_bias,
&topk_ids,
num_tokens,
top_k,
expert_start,
local_experts,
false,
)
};
@@ -611,20 +741,40 @@ impl GptOss {
let n = packed.shape()[1];
let k = packed.shape()[2] * 2;
xserv_kernels::moe::moe_sparse_gemv_mxfp4(
&activated, packed, scales, &layer.expert_down_bias, &topk_ids,
num_tokens, top_k, n, k, expert_start, local_experts, true,
&activated,
packed,
scales,
&layer.expert_down_bias,
&topk_ids,
num_tokens,
top_k,
n,
k,
expert_start,
local_experts,
true,
)
} else {
xserv_kernels::moe::moe_sparse_gemv_fp8(
&activated, layer.expert_down_fp8.as_ref().unwrap(),
&activated,
layer.expert_down_fp8.as_ref().unwrap(),
layer.expert_down_scale.as_ref().unwrap(),
&layer.expert_down_bias, &topk_ids,
num_tokens, top_k, expert_start, local_experts, true,
&layer.expert_down_bias,
&topk_ids,
num_tokens,
top_k,
expert_start,
local_experts,
true,
)
};
let moe_out = xserv_kernels::moe::moe_weighted_sum_sparse(
&down, &topk_ids, &topk_weights, expert_start, local_experts,
&down,
&topk_ids,
&topk_weights,
expert_start,
local_experts,
);
self.all_reduce(&moe_out);
return moe_out;
@@ -644,14 +794,24 @@ impl GptOss {
xserv_kernels::quantization::batched_gemv_mxfp4(&x2, packed, scales, n, k)
.reshape(&[local_experts, 1, n])
} else {
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(packed, scales, local_experts, n, k);
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(
packed,
scales,
local_experts,
n,
k,
);
xserv_kernels::moe::batched_gemm_strided(&x_rep, &w_bf16)
}
} else if let Some(ref wt_fp8_t) = layer.expert_gate_up_fp8 {
// W8A8: quantize activations with per-expert scalar scale, use cuBLASLt FP8 GEMM
let (x_fp8, x_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&x_rep);
let (x_fp8, x_scales) =
xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&x_rep);
xserv_kernels::quantization::batched_gemm_fp8(
&x_fp8, &x_scales, wt_fp8_t, layer.expert_gate_up_scale.as_ref().unwrap(),
&x_fp8,
&x_scales,
wt_fp8_t,
layer.expert_gate_up_scale.as_ref().unwrap(),
)
} else {
xserv_kernels::moe::batched_gemm_strided(&x_rep, &layer.expert_gate_up_wt)
@@ -677,14 +837,24 @@ impl GptOss {
xserv_kernels::quantization::batched_gemv_mxfp4(&a2, packed, scales, n, k)
.reshape(&[local_experts, 1, n])
} else {
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(packed, scales, local_experts, n, k);
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(
packed,
scales,
local_experts,
n,
k,
);
xserv_kernels::moe::batched_gemm_strided(&activated, &w_bf16)
}
} else if let Some(ref wt_fp8) = layer.expert_down_fp8 {
// W8A8: quantize post-GLU activations to FP8, use cuBLASLt FP8 GEMM
let (act_fp8, act_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&activated);
let (act_fp8, act_scales) =
xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&activated);
xserv_kernels::quantization::batched_gemm_fp8(
&act_fp8, &act_scales, wt_fp8, layer.expert_down_scale.as_ref().unwrap(),
&act_fp8,
&act_scales,
wt_fp8,
layer.expert_down_scale.as_ref().unwrap(),
)
} else {
xserv_kernels::moe::batched_gemm_strided(&activated, &layer.expert_down_wt)
@@ -695,8 +865,12 @@ impl GptOss {
// 9. Weighted sum across experts → [tokens, hidden]
let moe_out = xserv_kernels::moe::moe_weighted_sum(
&down, &topk_ids, &topk_weights,
expert_start, local_experts, top_k,
&down,
&topk_ids,
&topk_weights,
expert_start,
local_experts,
top_k,
);
self.all_reduce(&moe_out);
@@ -708,9 +882,7 @@ impl GptOss {
/// Upload a u32 slice to a pooled GPU buffer (synchronous H2D).
fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer {
let bytes = unsafe {
std::slice::from_raw_parts(vals.as_ptr() as *const u8, vals.len() * 4)
};
let bytes = unsafe { std::slice::from_raw_parts(vals.as_ptr() as *const u8, vals.len() * 4) };
let mut buf = xserv_cuda::allocator::cached_alloc(bytes.len()).expect("alloc u32 upload");
buf.copy_from_host(bytes).unwrap();
buf
@@ -737,11 +909,16 @@ fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor {
}
fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor {
if world == 1 { return t.clone(); }
if world == 1 {
return t.clone();
}
let shape = t.shape();
assert_eq!(shape.len(), 2);
let (rows, cols) = (shape[0], shape[1]);
assert!(rows % world == 0, "rows {rows} not divisible by world {world}");
assert!(
rows % world == 0,
"rows {rows} not divisible by world {world}"
);
let local = rows / world;
let host = t.to_device(Device::Cpu);
let data = host.as_slice::<bf16>();
@@ -751,11 +928,16 @@ fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor {
}
fn shard_cols(t: &Tensor, rank: usize, world: usize) -> Tensor {
if world == 1 { return t.clone(); }
if world == 1 {
return t.clone();
}
let shape = t.shape();
assert_eq!(shape.len(), 2);
let (rows, cols) = (shape[0], shape[1]);
assert!(cols % world == 0, "cols {cols} not divisible by world {world}");
assert!(
cols % world == 0,
"cols {cols} not divisible by world {world}"
);
let local = cols / world;
let c0 = rank * local;
let host = t.to_device(Device::Cpu);
@@ -769,11 +951,16 @@ fn shard_cols(t: &Tensor, rank: usize, world: usize) -> Tensor {
}
fn shard_1d(t: &Tensor, rank: usize, world: usize) -> Tensor {
if world == 1 { return t.clone(); }
if world == 1 {
return t.clone();
}
let shape = t.shape();
assert_eq!(shape.len(), 1);
let total = shape[0];
assert!(total % world == 0, "dim {total} not divisible by world {world}");
assert!(
total % world == 0,
"dim {total} not divisible by world {world}"
);
let local = total / world;
let host = t.to_device(Device::Cpu);
let data = host.as_slice::<bf16>();
@@ -804,7 +991,13 @@ fn transpose_3d_inner_raw(t: &Tensor, batch: usize, rows: usize, cols: usize) ->
}
/// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor (any dtype, raw bytes).
fn slice_expert_range_3d_raw(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor {
fn slice_expert_range_3d_raw(
t: &Tensor,
start: usize,
count: usize,
rows: usize,
cols: usize,
) -> Tensor {
assert_eq!(t.ndim(), 3);
let host = t.to_device(Device::Cpu);
let elem_size = t.dtype().size_bytes();
@@ -826,7 +1019,13 @@ fn slice_scale_range(t: &Tensor, start: usize, count: usize) -> Tensor {
}
/// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor
fn slice_expert_range_3d(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor {
fn slice_expert_range_3d(
t: &Tensor,
start: usize,
count: usize,
rows: usize,
cols: usize,
) -> Tensor {
assert_eq!(t.ndim(), 3);
let host = t.to_device(Device::Cpu);
let data = host.as_slice::<bf16>();

View File

@@ -59,7 +59,9 @@ impl GptOssDecodeGraph {
model.decode_prepare(&[position], &[slot], cache);
ids_buf.copy_from_host(&token.to_le_bytes()).unwrap();
pos_buf.copy_from_host(&(position as u32).to_le_bytes()).unwrap();
pos_buf
.copy_from_host(&(position as u32).to_le_bytes())
.unwrap();
// Retained warmup: run the exact step once eagerly with the quarantine
// ON. Freed intermediates are held back instead of recycled, so the
@@ -88,21 +90,32 @@ impl GptOssDecodeGraph {
let logits;
{
let _guard = xserv_cuda::stream::push_stream(&stream);
graph.begin_capture(&stream).expect("begin decode-graph capture");
graph
.begin_capture(&stream)
.expect("begin decode-graph capture");
logits = model.decode_core(
ids_buf.as_ptr() as *const c_void,
pos_buf.as_ptr() as *const c_void,
1,
cache,
);
graph.end_capture(&stream).expect("end decode-graph capture");
graph
.end_capture(&stream)
.expect("end decode-graph capture");
}
let arena = allocator::end_retain();
graph.launch(&stream).expect("first decode-graph replay");
cache.advance_seq_len(slot, 1);
Self { stream, graph, ids_buf, pos_buf, logits, _arena: arena }
Self {
stream,
graph,
ids_buf,
pos_buf,
logits,
_arena: arena,
}
}
/// Run one decode step by replaying the captured graph.
@@ -116,8 +129,12 @@ impl GptOssDecodeGraph {
) -> Tensor {
model.decode_prepare(&[position], &[slot], cache);
self.ids_buf.copy_from_host(&token.to_le_bytes()).unwrap();
self.pos_buf.copy_from_host(&(position as u32).to_le_bytes()).unwrap();
self.graph.launch(&self.stream).expect("decode-graph replay");
self.pos_buf
.copy_from_host(&(position as u32).to_le_bytes())
.unwrap();
self.graph
.launch(&self.stream)
.expect("decode-graph replay");
cache.advance_seq_len(slot, 1);
// Shallow clone: the caller reads these logits before the next replay
// rewrites the underlying buffer.
@@ -137,8 +154,14 @@ pub struct GraphedGptOssDecoder {
impl GraphedGptOssDecoder {
pub fn new() -> Self {
let enabled = std::env::var("XSERV_DECODE_GRAPH").map(|v| v != "0").unwrap_or(true);
Self { graph: None, eager_steps: 0, enabled }
let enabled = std::env::var("XSERV_DECODE_GRAPH")
.map(|v| v != "0")
.unwrap_or(true);
Self {
graph: None,
eager_steps: 0,
enabled,
}
}
pub fn decode(

View File

@@ -1,6 +1,6 @@
use crate::config::ModelConfig;
use xserv_cuda::GpuBuffer;
use xserv_tensor::{DType, Tensor};
use crate::config::ModelConfig;
/// GPU-resident KV cache. Pre-allocates max_seq_len on GPU,
/// appends new K/V via D2D copy at offset (no CPU round-trip).
@@ -46,17 +46,43 @@ impl GpuKVCache {
v_staging.push(GpuBuffer::alloc(buf_size).expect("alloc KV staging V"));
}
Self { k_bufs, v_bufs, k_staging, v_staging, seq_len: 0, max_seq_len, num_kv_heads, head_dim, elem_size, dtype, device }
Self {
k_bufs,
v_bufs,
k_staging,
v_staging,
seq_len: 0,
max_seq_len,
num_kv_heads,
head_dim,
elem_size,
dtype,
device,
}
}
pub fn seq_len(&self) -> usize { self.seq_len }
pub fn max_seq_len(&self) -> usize { self.max_seq_len }
pub fn seq_len(&self) -> usize {
self.seq_len
}
pub fn max_seq_len(&self) -> usize {
self.max_seq_len
}
/// Append new K/V tensors for a given layer.
/// k_new, v_new: [1, num_kv_heads, new_tokens, head_dim] on GPU, contiguous.
/// `write_pos` is the sequence position to write at (caller manages this).
pub fn append(&mut self, layer: usize, k_new: &Tensor, v_new: &Tensor, new_tokens: usize, write_pos: usize) {
assert!(write_pos + new_tokens <= self.max_seq_len, "KV cache overflow");
pub fn append(
&mut self,
layer: usize,
k_new: &Tensor,
v_new: &Tensor,
new_tokens: usize,
write_pos: usize,
) {
assert!(
write_pos + new_tokens <= self.max_seq_len,
"KV cache overflow"
);
let es = self.elem_size;
let hd = self.head_dim;
let max_s = self.max_seq_len;
@@ -69,14 +95,23 @@ impl GpuKVCache {
let src_off = h * new_tokens * hd * es;
let dst_off = (h * max_s + write_pos) * hd * es;
let count = new_tokens * hd * es;
self.k_bufs[layer].copy_from_device_at(k_src, src_off, dst_off, count).unwrap();
self.v_bufs[layer].copy_from_device_at(v_src, src_off, dst_off, count).unwrap();
self.k_bufs[layer]
.copy_from_device_at(k_src, src_off, dst_off, count)
.unwrap();
self.v_bufs[layer]
.copy_from_device_at(v_src, src_off, dst_off, count)
.unwrap();
}
}
pub fn advance_seq_len(&mut self, new_tokens: usize) {
self.seq_len += new_tokens;
assert!(self.seq_len <= self.max_seq_len, "KV cache seq_len ({}) exceeds max_seq_len ({})", self.seq_len, self.max_seq_len);
assert!(
self.seq_len <= self.max_seq_len,
"KV cache seq_len ({}) exceeds max_seq_len ({})",
self.seq_len,
self.max_seq_len
);
}
/// Get K/V cache tensors for a layer up to `seq_len` tokens: [1, num_kv_heads, seq_len, head_dim]
@@ -86,7 +121,11 @@ impl GpuKVCache {
}
pub fn get_kv_len(&mut self, layer: usize, sl: usize) -> (Tensor, Tensor) {
assert!(sl <= self.max_seq_len, "get_kv_len: sl ({sl}) exceeds max_seq_len ({})", self.max_seq_len);
assert!(
sl <= self.max_seq_len,
"get_kv_len: sl ({sl}) exceeds max_seq_len ({})",
self.max_seq_len
);
let hd = self.head_dim;
let nh = self.num_kv_heads;
let es = self.elem_size;
@@ -104,8 +143,12 @@ impl GpuKVCache {
let src_off = (h * max_s) * hd * es;
let dst_off = (h * sl) * hd * es;
let count = sl * hd * es;
k_stg.copy_from_device_at(k_buf, src_off, dst_off, count).unwrap();
v_stg.copy_from_device_at(v_buf, src_off, dst_off, count).unwrap();
k_stg
.copy_from_device_at(k_buf, src_off, dst_off, count)
.unwrap();
v_stg
.copy_from_device_at(v_buf, src_off, dst_off, count)
.unwrap();
}
// Grab raw pointers before dropping the mutable borrows
let k_ptr = k_stg.as_mut_ptr();
@@ -117,20 +160,35 @@ impl GpuKVCache {
// get_kv_len call overwrites the staging buffer).
let shape = &[1usize, nh, sl, hd];
let k = unsafe {
tensor_from_gpu_buffer(GpuBuffer::borrow_raw(k_ptr, out_size), shape, self.dtype, self.device)
tensor_from_gpu_buffer(
GpuBuffer::borrow_raw(k_ptr, out_size),
shape,
self.dtype,
self.device,
)
};
let v = unsafe {
tensor_from_gpu_buffer(GpuBuffer::borrow_raw(v_ptr, out_size), shape, self.dtype, self.device)
tensor_from_gpu_buffer(
GpuBuffer::borrow_raw(v_ptr, out_size),
shape,
self.dtype,
self.device,
)
};
(k, v)
}
}
/// Create a Tensor from a GpuBuffer (takes ownership).
unsafe fn tensor_from_gpu_buffer(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor {
use xserv_tensor::storage::Storage;
use xserv_tensor::shape::contiguous_strides;
unsafe fn tensor_from_gpu_buffer(
buf: GpuBuffer,
shape: &[usize],
dtype: DType,
device: u32,
) -> Tensor {
use smallvec::SmallVec;
use xserv_tensor::shape::contiguous_strides;
use xserv_tensor::storage::Storage;
let storage = Storage::cuda(buf, device);
Tensor::from_storage(
@@ -146,6 +204,11 @@ unsafe fn tensor_from_gpu_buffer(buf: GpuBuffer, shape: &[usize], dtype: DType,
///
/// # Safety
/// `buf` must be a valid GPU allocation with at least `product(shape) * dtype.size_bytes()` bytes.
pub unsafe fn tensor_from_gpu_buffer_pub(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor {
pub unsafe fn tensor_from_gpu_buffer_pub(
buf: GpuBuffer,
shape: &[usize],
dtype: DType,
device: u32,
) -> Tensor {
tensor_from_gpu_buffer(buf, shape, dtype, device)
}

View File

@@ -11,11 +11,11 @@ pub mod sampling;
pub use config::ModelConfig;
pub use decode_graph::{DecodeGraphState, LayerWeightPtrs};
pub use gpt2::{GPT2, KVCache};
pub use gpt_oss::GptOss;
pub use gpt_oss_graph::{GptOssDecodeGraph, GraphedGptOssDecoder};
pub use gpt2::{GPT2, KVCache};
pub use kv_cache::GpuKVCache;
pub use paged_kv_cache::{BlockAllocator, Location, PagedKVCache, BLOCK_SIZE};
pub use paged_kv_cache::{BLOCK_SIZE, BlockAllocator, Location, PagedKVCache};
pub use qwen3::Qwen3;
pub use sampling::{SamplingParams, sample, sample_greedy_penalized};

View File

@@ -5,8 +5,8 @@ use std::path::Path;
use xserv_tensor::{DType, Device, Tensor};
pub fn load_safetensors(path: &Path, device: Device) -> HashMap<String, Tensor> {
let data = std::fs::read(path)
.unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
let data =
std::fs::read(path).unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
let st = SafeTensors::deserialize(&data)
.unwrap_or_else(|e| panic!("failed to parse safetensors {}: {e}", path.display()));
@@ -60,7 +60,11 @@ pub fn load_model_dir(dir: &Path, device: Device) -> HashMap<String, Tensor> {
all_tensors.extend(tensors);
}
assert!(!all_tensors.is_empty(), "no safetensors files found in {}", dir.display());
assert!(
!all_tensors.is_empty(),
"no safetensors files found in {}",
dir.display()
);
all_tensors
}
@@ -84,8 +88,6 @@ fn make_tensor(raw_bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor {
};
Tensor::from_slice(bfs, shape)
}
DType::FP8E4M3 => {
Tensor::from_raw_bytes(raw_bytes, shape, DType::FP8E4M3)
}
DType::FP8E4M3 => Tensor::from_raw_bytes(raw_bytes, shape, DType::FP8E4M3),
}
}

View File

@@ -29,7 +29,10 @@ impl BlockAllocator {
for b in (1..total_blocks).rev() {
free_stack.push(b as u32);
}
Self { free_stack, total: total_blocks }
Self {
free_stack,
total: total_blocks,
}
}
pub fn alloc(&mut self) -> Option<u32> {
@@ -136,8 +139,14 @@ impl PagedKVCache {
device: u32,
) -> Self {
Self::new_tp(
config, config.num_kv_heads(), total_blocks, cpu_total_blocks,
max_seqs, max_blocks_per_seq, dtype, device,
config,
config.num_kv_heads(),
total_blocks,
cpu_total_blocks,
max_seqs,
max_blocks_per_seq,
dtype,
device,
)
}
@@ -155,7 +164,10 @@ impl PagedKVCache {
dtype: DType,
device: u32,
) -> Self {
assert!(total_blocks >= 2, "need at least 2 blocks (one is sentinel)");
assert!(
total_blocks >= 2,
"need at least 2 blocks (one is sentinel)"
);
let num_layers = config.num_layers();
let head_dim = config.head_dim();
let elem_size = dtype.size_bytes();
@@ -179,11 +191,17 @@ impl PagedKVCache {
if cpu_total_blocks >= 2 {
let cpu_pool_bytes = cpu_total_blocks * block_bytes;
for _ in 0..num_layers {
cpu_k_pools.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU K swap pool"));
cpu_v_pools.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU V swap pool"));
cpu_k_pools
.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU K swap pool"));
cpu_v_pools
.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU V swap pool"));
}
}
let cpu_allocator = BlockAllocator::new(if cpu_total_blocks >= 2 { cpu_total_blocks } else { 0 });
let cpu_allocator = BlockAllocator::new(if cpu_total_blocks >= 2 {
cpu_total_blocks
} else {
0
});
let block_table_gpu =
GpuBuffer::alloc(max_seqs * max_blocks_per_seq * std::mem::size_of::<i32>())
@@ -220,22 +238,49 @@ impl PagedKVCache {
}
}
pub fn num_layers(&self) -> usize { self.num_layers }
pub fn num_kv_heads(&self) -> usize { self.num_kv_heads }
pub fn head_dim(&self) -> usize { self.head_dim }
pub fn dtype(&self) -> DType { self.dtype }
pub fn max_seqs(&self) -> usize { self.max_seqs }
pub fn max_blocks_per_seq(&self) -> usize { self.max_blocks_per_seq }
pub fn free_blocks(&self) -> usize { self.allocator.free_count() }
pub fn total_blocks(&self) -> usize { self.allocator.total() }
pub fn num_layers(&self) -> usize {
self.num_layers
}
pub fn num_kv_heads(&self) -> usize {
self.num_kv_heads
}
pub fn head_dim(&self) -> usize {
self.head_dim
}
pub fn dtype(&self) -> DType {
self.dtype
}
pub fn max_seqs(&self) -> usize {
self.max_seqs
}
pub fn max_blocks_per_seq(&self) -> usize {
self.max_blocks_per_seq
}
pub fn free_blocks(&self) -> usize {
self.allocator.free_count()
}
pub fn total_blocks(&self) -> usize {
self.allocator.total()
}
pub fn k_pool(&self, layer: usize) -> &GpuBuffer { &self.k_pools[layer] }
pub fn v_pool(&self, layer: usize) -> &GpuBuffer { &self.v_pools[layer] }
pub fn block_table_gpu(&self) -> &GpuBuffer { &self.block_table_gpu }
pub fn context_lens_gpu(&self) -> &GpuBuffer { &self.context_lens_gpu }
pub fn k_pool(&self, layer: usize) -> &GpuBuffer {
&self.k_pools[layer]
}
pub fn v_pool(&self, layer: usize) -> &GpuBuffer {
&self.v_pools[layer]
}
pub fn block_table_gpu(&self) -> &GpuBuffer {
&self.block_table_gpu
}
pub fn context_lens_gpu(&self) -> &GpuBuffer {
&self.context_lens_gpu
}
pub fn seq_len(&self, slot: usize) -> usize {
self.seq_states[slot].as_ref().map(|s| s.seq_len).unwrap_or(0)
self.seq_states[slot]
.as_ref()
.map(|s| s.seq_len)
.unwrap_or(0)
}
pub fn is_slot_free(&self, slot: usize) -> bool {
@@ -280,7 +325,11 @@ impl PagedKVCache {
let state = self.seq_states[slot].as_ref().expect("unregistered slot");
let cur = state.block_ids.len();
let needed_total = (state.seq_len + new_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE;
if needed_total > cur { needed_total - cur } else { 0 }
if needed_total > cur {
needed_total - cur
} else {
0
}
}
/// Pre-allocate enough physical blocks in `slot` to cover positions
@@ -290,8 +339,14 @@ impl PagedKVCache {
let state = self.seq_states[slot].as_mut().expect("unregistered slot");
let needed_total = (end_pos + BLOCK_SIZE - 1) / BLOCK_SIZE;
while state.block_ids.len() < needed_total {
let b = self.allocator.alloc().expect("out of blocks (caller must check)");
assert!(state.block_ids.len() < self.max_blocks_per_seq, "block table overflow");
let b = self
.allocator
.alloc()
.expect("out of blocks (caller must check)");
assert!(
state.block_ids.len() < self.max_blocks_per_seq,
"block table overflow"
);
state.block_ids.push(b);
}
}
@@ -318,7 +373,9 @@ impl PagedKVCache {
num_tokens: usize,
start_pos: usize,
) {
if num_tokens == 0 { return; }
if num_tokens == 0 {
return;
}
// Make sure blocks exist for the target range.
self.ensure_capacity(slot, start_pos + num_tokens);
@@ -328,15 +385,21 @@ impl PagedKVCache {
// Stage block_ids on the GPU. Pool-allocated so this is essentially
// free after the first call (same bucket every step).
let block_ids: Vec<i32> = self.seq_states[slot].as_ref().unwrap()
.block_ids.iter().map(|&b| b as i32).collect();
let block_ids: Vec<i32> = self.seq_states[slot]
.as_ref()
.unwrap()
.block_ids
.iter()
.map(|&b| b as i32)
.collect();
let bytes = block_ids.len() * std::mem::size_of::<i32>();
let mut block_ids_gpu = xserv_cuda::allocator::cached_alloc(bytes)
.expect("alloc append block_ids");
let block_ids_bytes = unsafe {
std::slice::from_raw_parts(block_ids.as_ptr() as *const u8, bytes)
};
block_ids_gpu.copy_from_host(block_ids_bytes).expect("upload block_ids");
let mut block_ids_gpu =
xserv_cuda::allocator::cached_alloc(bytes).expect("alloc append block_ids");
let block_ids_bytes =
unsafe { std::slice::from_raw_parts(block_ids.as_ptr() as *const u8, bytes) };
block_ids_gpu
.copy_from_host(block_ids_bytes)
.expect("upload block_ids");
let k_src = k_new.data_ptr() as *const std::ffi::c_void;
let v_src = v_new.data_ptr() as *const std::ffi::c_void;
@@ -345,10 +408,16 @@ impl PagedKVCache {
unsafe {
xserv_kernels::reshape_and_cache_bf16(
k_src, v_src,
k_pool_ptr, v_pool_ptr,
k_src,
v_src,
k_pool_ptr,
v_pool_ptr,
block_ids_gpu.as_ptr() as *const i32,
num_tokens, nkv, hd, start_pos, bs,
num_tokens,
nkv,
hd,
start_pos,
bs,
xserv_cuda::current_stream_raw(),
);
}
@@ -378,7 +447,9 @@ impl PagedKVCache {
v_new: &Tensor,
batch: usize,
) {
if batch == 0 { return; }
if batch == 0 {
return;
}
let nkv = self.num_kv_heads;
let hd = self.head_dim;
debug_assert_eq!(k_new.shape(), &[batch, nkv, hd]);
@@ -393,10 +464,17 @@ impl PagedKVCache {
unsafe {
xserv_kernels::reshape_and_cache_batched_bf16(
k_src, v_src,
k_pool_ptr, v_pool_ptr,
bt_ptr, cl_ptr,
batch, nkv, hd, BLOCK_SIZE, self.max_blocks_per_seq,
k_src,
v_src,
k_pool_ptr,
v_pool_ptr,
bt_ptr,
cl_ptr,
batch,
nkv,
hd,
BLOCK_SIZE,
self.max_blocks_per_seq,
xserv_cuda::current_stream_raw(),
);
}
@@ -447,7 +525,10 @@ impl PagedKVCache {
/// before advance_seq_len has run).
pub fn sync_active_batch_with_lens(&mut self, slots: &[usize], kv_lens: &[i32]) {
assert_eq!(slots.len(), kv_lens.len());
assert!(slots.len() <= self.max_seqs, "active batch exceeds max_seqs");
assert!(
slots.len() <= self.max_seqs,
"active batch exceeds max_seqs"
);
let stride = self.max_blocks_per_seq;
for row in &mut self.block_table_host {
*row = 0;
@@ -456,7 +537,9 @@ impl PagedKVCache {
*cl = 0;
}
for (i, &slot) in slots.iter().enumerate() {
let s = self.seq_states[slot].as_ref().expect("unregistered slot in active batch");
let s = self.seq_states[slot]
.as_ref()
.expect("unregistered slot in active batch");
let row = &mut self.block_table_host[i * stride..(i + 1) * stride];
for (j, b) in s.block_ids.iter().enumerate() {
row[j] = *b as i32;
@@ -515,8 +598,12 @@ impl PagedKVCache {
let src_off = ((phys * nkv + h) * bs + slot_in_blk) * hd * es;
let dst_off = (h * sl + p) * hd * es;
let count = chunk * hd * es;
k_dst.copy_from_device_at(k_pool, src_off, dst_off, count).unwrap();
v_dst.copy_from_device_at(v_pool, src_off, dst_off, count).unwrap();
k_dst
.copy_from_device_at(k_pool, src_off, dst_off, count)
.unwrap();
v_dst
.copy_from_device_at(v_pool, src_off, dst_off, count)
.unwrap();
}
p += chunk;
}
@@ -529,16 +616,26 @@ impl PagedKVCache {
// ----- Swapping (vLLM-style preemption to pinned host memory) -----
pub fn free_cpu_blocks(&self) -> usize { self.cpu_allocator.free_count() }
pub fn swap_enabled(&self) -> bool { !self.cpu_k_pools.is_empty() }
pub fn free_cpu_blocks(&self) -> usize {
self.cpu_allocator.free_count()
}
pub fn swap_enabled(&self) -> bool {
!self.cpu_k_pools.is_empty()
}
pub fn is_swapped(&self, slot: usize) -> bool {
matches!(self.seq_states[slot].as_ref().map(|s| s.location), Some(Location::Cpu))
matches!(
self.seq_states[slot].as_ref().map(|s| s.location),
Some(Location::Cpu)
)
}
/// Number of physical blocks currently held by `slot` (in either pool).
pub fn block_count(&self, slot: usize) -> usize {
self.seq_states[slot].as_ref().map(|s| s.block_ids.len()).unwrap_or(0)
self.seq_states[slot]
.as_ref()
.map(|s| s.block_ids.len())
.unwrap_or(0)
}
/// Whether a swapped sequence at `slot` can be brought back (enough free GPU blocks).
@@ -554,11 +651,17 @@ impl PagedKVCache {
/// Evict `slot`'s KV from GPU to pinned host memory and free its GPU blocks.
/// The slot stays registered (location = Cpu); the sequence is paused.
pub fn swap_out(&mut self, slot: usize) -> Result<(), &'static str> {
let state = self.seq_states[slot].as_ref().ok_or("swap_out: empty slot")?;
if state.location == Location::Cpu { return Ok(()); }
let state = self.seq_states[slot]
.as_ref()
.ok_or("swap_out: empty slot")?;
if state.location == Location::Cpu {
return Ok(());
}
let gpu_ids = state.block_ids.clone();
let n = gpu_ids.len();
if !self.cpu_allocator.can_alloc(n) { return Err("swap_out: CPU pool full"); }
if !self.cpu_allocator.can_alloc(n) {
return Err("swap_out: CPU pool full");
}
let cpu_ids: Vec<u32> = (0..n)
.map(|_| self.cpu_allocator.alloc().expect("checked can_alloc"))
@@ -570,10 +673,18 @@ impl PagedKVCache {
let g_off = gpu_ids[i] as usize * bb;
let c_off = cpu_ids[i] as usize * bb;
self.k_pools[layer]
.copy_to_host_at(&mut self.cpu_k_pools[layer].as_mut_slice()[c_off..c_off + bb], g_off, bb)
.copy_to_host_at(
&mut self.cpu_k_pools[layer].as_mut_slice()[c_off..c_off + bb],
g_off,
bb,
)
.unwrap();
self.v_pools[layer]
.copy_to_host_at(&mut self.cpu_v_pools[layer].as_mut_slice()[c_off..c_off + bb], g_off, bb)
.copy_to_host_at(
&mut self.cpu_v_pools[layer].as_mut_slice()[c_off..c_off + bb],
g_off,
bb,
)
.unwrap();
}
}
@@ -589,11 +700,17 @@ impl PagedKVCache {
/// Bring `slot`'s KV back from host to GPU and free its CPU blocks.
pub fn swap_in(&mut self, slot: usize) -> Result<(), &'static str> {
let state = self.seq_states[slot].as_ref().ok_or("swap_in: empty slot")?;
if state.location == Location::Gpu { return Ok(()); }
let state = self.seq_states[slot]
.as_ref()
.ok_or("swap_in: empty slot")?;
if state.location == Location::Gpu {
return Ok(());
}
let cpu_ids = state.block_ids.clone();
let n = cpu_ids.len();
if !self.allocator.can_alloc(n) { return Err("swap_in: GPU pool full"); }
if !self.allocator.can_alloc(n) {
return Err("swap_in: GPU pool full");
}
let gpu_ids: Vec<u32> = (0..n)
.map(|_| self.allocator.alloc().expect("checked can_alloc"))
@@ -605,10 +722,18 @@ impl PagedKVCache {
let g_off = gpu_ids[i] as usize * bb;
let c_off = cpu_ids[i] as usize * bb;
self.k_pools[layer]
.copy_from_host_at(&self.cpu_k_pools[layer].as_slice()[c_off..c_off + bb], g_off, bb)
.copy_from_host_at(
&self.cpu_k_pools[layer].as_slice()[c_off..c_off + bb],
g_off,
bb,
)
.unwrap();
self.v_pools[layer]
.copy_from_host_at(&self.cpu_v_pools[layer].as_slice()[c_off..c_off + bb], g_off, bb)
.copy_from_host_at(
&self.cpu_v_pools[layer].as_slice()[c_off..c_off + bb],
g_off,
bb,
)
.unwrap();
}
}
@@ -623,7 +748,12 @@ impl PagedKVCache {
}
}
unsafe fn tensor_from_owned_buf(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor {
unsafe fn tensor_from_owned_buf(
buf: GpuBuffer,
shape: &[usize],
dtype: DType,
device: u32,
) -> Tensor {
use smallvec::SmallVec;
use xserv_tensor::shape::contiguous_strides;
use xserv_tensor::storage::Storage;

View File

@@ -1,5 +1,5 @@
use std::collections::HashMap;
use half::bf16;
use std::collections::HashMap;
use xserv_kernels::*;
use xserv_tensor::{Device, Tensor};
@@ -13,7 +13,7 @@ pub struct Qwen3 {
embed_tokens: Tensor,
layers: Vec<Qwen3Block>,
norm: Tensor,
lm_head_t: Tensor, // precomputed transpose
lm_head_t: Tensor, // precomputed transpose
rope_cache: RopeCache,
// Tensor parallelism. `tp` is None (or world==1) for single-GPU; otherwise
// this rank holds 1/world of the heads and AllReduces after o_proj/down_proj.
@@ -28,22 +28,29 @@ pub struct Qwen3 {
}
struct Qwen3Block {
input_norm: Tensor, // [hidden]
input_norm: Tensor, // [hidden]
qkv_proj_wt: Tensor, // FUSED: [hidden, (H+2*KV)*D] — Q|K|V columns
q_dim: usize, // num_heads * head_dim (Q slice boundary)
kv_dim: usize, // num_kv_heads * head_dim (K/V slice size)
o_proj_wt: Tensor, // TRANSPOSED: [num_heads*head_dim, hidden]
q_norm: Tensor, // [head_dim]
k_norm: Tensor, // [head_dim]
post_norm: Tensor, // [hidden]
gate_up_proj_wt: Tensor, // FUSED: [hidden, 2*intermediate]
down_proj_wt: Tensor, // TRANSPOSED: [intermediate, hidden]
o_proj_wt: Tensor, // TRANSPOSED: [num_heads*head_dim, hidden]
q_norm: Tensor, // [head_dim]
k_norm: Tensor, // [head_dim]
post_norm: Tensor, // [hidden]
gate_up_proj_wt: Tensor, // FUSED: [hidden, 2*intermediate]
down_proj_wt: Tensor, // TRANSPOSED: [intermediate, hidden]
}
impl Qwen3Block {
fn q_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, 0, self.q_dim) }
fn k_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, self.q_dim, self.kv_dim) }
fn v_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, self.q_dim + self.kv_dim, self.kv_dim) }
fn q_proj_wt(&self) -> Tensor {
self.qkv_proj_wt.narrow(1, 0, self.q_dim)
}
fn k_proj_wt(&self) -> Tensor {
self.qkv_proj_wt.narrow(1, self.q_dim, self.kv_dim)
}
fn v_proj_wt(&self) -> Tensor {
self.qkv_proj_wt
.narrow(1, self.q_dim + self.kv_dim, self.kv_dim)
}
fn gate_proj_wt(&self) -> Tensor {
let half = self.gate_up_proj_wt.shape()[1] / 2;
self.gate_up_proj_wt.narrow(1, 0, half)
@@ -80,18 +87,31 @@ impl Qwen3 {
crate::init_kernels();
let dev = Device::Cuda(device);
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
w.remove(name)
.unwrap_or_else(|| panic!("missing weight: {name}"))
};
// Replicated weight: upload whole to this rank's device.
let repl = |t: Tensor| -> Tensor { t.to_device(dev) };
// column-parallel: keep this rank's rows of [out, in], upload, transpose → [in, out/world].
let col = |t: Tensor| -> Tensor { shard_rows(&t, rank, world).to_device(dev).transpose(0, 1).contiguous() };
let col = |t: Tensor| -> Tensor {
shard_rows(&t, rank, world)
.to_device(dev)
.transpose(0, 1)
.contiguous()
};
// row-parallel: keep this rank's cols of [out, in], upload, transpose → [in/world, out].
let row = |t: Tensor| -> Tensor { shard_cols(&t, rank, world).to_device(dev).transpose(0, 1).contiguous() };
let row = |t: Tensor| -> Tensor {
shard_cols(&t, rank, world)
.to_device(dev)
.transpose(0, 1)
.contiguous()
};
let embed_tokens = repl(take(&mut w, "model.embed_tokens.weight"));
let norm = repl(take(&mut w, "model.norm.weight"));
let lm_head_t = repl(take(&mut w, "lm_head.weight")).transpose(0, 1).contiguous();
let lm_head_t = repl(take(&mut w, "lm_head.weight"))
.transpose(0, 1)
.contiguous();
let rope_cache = RopeCache::new(
config.max_seq_len(),
@@ -102,7 +122,10 @@ impl Qwen3 {
let num_layers = config.num_layers();
let mut layers = Vec::with_capacity(num_layers);
if rank == 0 {
eprintln!("Loading+sharding weights for {} layers (world={world})...", num_layers);
eprintln!(
"Loading+sharding weights for {} layers (world={world})...",
num_layers
);
}
for i in 0..num_layers {
let p = format!("model.layers.{i}");
@@ -126,7 +149,10 @@ impl Qwen3 {
o_proj_wt: row(take(&mut w, &format!("{p}.self_attn.o_proj.weight"))),
q_norm: repl(take(&mut w, &format!("{p}.self_attn.q_norm.weight"))),
k_norm: repl(take(&mut w, &format!("{p}.self_attn.k_norm.weight"))),
post_norm: repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight"))),
post_norm: repl(take(
&mut w,
&format!("{p}.post_attention_layernorm.weight"),
)),
gate_up_proj_wt,
down_proj_wt: row(take(&mut w, &format!("{p}.mlp.down_proj.weight"))),
});
@@ -165,7 +191,10 @@ impl Qwen3 {
let dev = Device::Cuda(device);
assert!(num_stages >= 1);
let num_layers = config.num_layers();
assert!(num_layers % num_stages == 0, "num_layers {num_layers} not divisible by pp {num_stages}");
assert!(
num_layers % num_stages == 0,
"num_layers {num_layers} not divisible by pp {num_stages}"
);
let per_stage = num_layers / num_stages;
let lo = stage * per_stage;
let hi = lo + per_stage;
@@ -173,16 +202,29 @@ impl Qwen3 {
let is_last_stage = stage == num_stages - 1;
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
w.remove(name)
.unwrap_or_else(|| panic!("missing weight: {name}"))
};
let repl = |t: Tensor| -> Tensor { t.to_device(dev) };
// Pre-transpose like the TP path's `col`/`row` do for world==1 (no shard).
let wt = |t: Tensor| -> Tensor { t.to_device(dev).transpose(0, 1).contiguous() };
let placeholder = || Tensor::from_slice(&[bf16::ZERO], &[1, 1]).to_device(dev);
let embed_tokens = if is_first_stage { repl(take(&mut w, "model.embed_tokens.weight")) } else { placeholder() };
let norm = if is_last_stage { repl(take(&mut w, "model.norm.weight")) } else { placeholder() };
let lm_head_t = if is_last_stage { wt(take(&mut w, "lm_head.weight")) } else { placeholder() };
let embed_tokens = if is_first_stage {
repl(take(&mut w, "model.embed_tokens.weight"))
} else {
placeholder()
};
let norm = if is_last_stage {
repl(take(&mut w, "model.norm.weight"))
} else {
placeholder()
};
let lm_head_t = if is_last_stage {
wt(take(&mut w, "lm_head.weight"))
} else {
placeholder()
};
let rope_cache = RopeCache::new(
config.max_seq_len(),
@@ -217,7 +259,10 @@ impl Qwen3 {
o_proj_wt: wt(take(&mut w, &format!("{p}.self_attn.o_proj.weight"))),
q_norm: repl(take(&mut w, &format!("{p}.self_attn.q_norm.weight"))),
k_norm: repl(take(&mut w, &format!("{p}.self_attn.k_norm.weight"))),
post_norm: repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight"))),
post_norm: repl(take(
&mut w,
&format!("{p}.post_attention_layernorm.weight"),
)),
gate_up_proj_wt,
down_proj_wt: wt(take(&mut w, &format!("{p}.mlp.down_proj.weight"))),
});
@@ -252,8 +297,12 @@ impl Qwen3 {
matmul_2d(&x, &self.lm_head_t)
}
pub fn pp_is_first(&self) -> bool { self.is_first_stage }
pub fn pp_is_last(&self) -> bool { self.is_last_stage }
pub fn pp_is_first(&self) -> bool {
self.is_first_stage
}
pub fn pp_is_last(&self) -> bool {
self.is_last_stage
}
/// PP prefill over THIS stage's layers. `x` is `[S, hidden]` (stage 0: from
/// `embed`; otherwise received from the previous stage). Writes K/V for this
@@ -276,7 +325,9 @@ impl Qwen3 {
paged_cache.ensure_capacity(slot, pos_offset + new_tokens);
paged_cache.advance_seq_len(slot, new_tokens);
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
.map(|p| p as u32)
.collect();
for (layer_idx, layer) in self.layers.iter().enumerate() {
let residual = x.clone();
@@ -285,7 +336,9 @@ impl Qwen3 {
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
let v = qkv
.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
.contiguous();
let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
@@ -305,10 +358,12 @@ impl Qwen3 {
let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
let attn_out = flash_attention(&q, &k_full, &v_full, true);
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_merged =
xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let (normed, x_new) =
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
@@ -356,7 +411,9 @@ impl Qwen3 {
let qkv_all = matmul_2d(&normed, &layer.qkv_proj_wt);
let q_all = qkv_all.narrow(1, 0, layer.q_dim).contiguous();
let k_all = qkv_all.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
let v_all = qkv_all.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
let v_all = qkv_all
.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
.contiguous();
let mut q_rows: Vec<Tensor> = Vec::with_capacity(batch);
for b in 0..batch {
@@ -394,14 +451,23 @@ impl Qwen3 {
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
let attn_out = xserv_kernels::paged_decode_attention(
&q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr,
batch, num_heads, num_kv_heads, head_dim, max_blocks,
&q_4d,
k_pool_ptr,
v_pool_ptr,
bt_ptr,
cl_ptr,
batch,
num_heads,
num_kv_heads,
head_dim,
max_blocks,
);
let attn_merged = attn_out.reshape(&[batch, num_heads * head_dim]);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let (normed, x_new) =
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
@@ -441,7 +507,9 @@ impl Qwen3 {
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
let mut x = embedding(&self.embed_tokens, token_ids);
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
.map(|p| p as u32)
.collect();
for (layer_idx, layer) in self.layers.iter().enumerate() {
let residual = x.clone();
@@ -450,7 +518,9 @@ impl Qwen3 {
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
let v = qkv
.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
.contiguous();
let q = reshape_heads(&q, new_tokens, num_heads, head_dim);
let k = reshape_heads(&k, new_tokens, num_kv_heads, head_dim);
@@ -531,7 +601,9 @@ impl Qwen3 {
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
let q_all = qkv.narrow(1, 0, layer.q_dim).contiguous();
let k_all = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
let v_all = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
let v_all = qkv
.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
.contiguous();
// Per-sequence: reshape, qk-norm, RoPE, KV cache, attention, merge
let mut attn_outputs: Vec<Tensor> = Vec::with_capacity(batch);
@@ -583,7 +655,8 @@ impl Qwen3 {
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
// Fused add + rmsnorm
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let (normed, x_new) =
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
@@ -662,13 +735,15 @@ impl Qwen3 {
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt); // [B, (H+2*KV)*D]
let q_dim = num_heads * head_dim;
let kv_dim = num_kv_heads * head_dim;
let q_all = qkv.narrow(1, 0, q_dim); // [B, H*D] (view)
let k_all = qkv.narrow(1, q_dim, kv_dim); // [B, KV*D] (view)
let q_all = qkv.narrow(1, 0, q_dim); // [B, H*D] (view)
let k_all = qkv.narrow(1, q_dim, kv_dim); // [B, KV*D] (view)
let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim);
// Per-head RMSNorm on contiguous copies (narrow views are strided).
let q_flat = q_all.contiguous().reshape(&[batch * num_heads, head_dim]);
let k_flat = k_all.contiguous().reshape(&[batch * num_kv_heads, head_dim]);
let k_flat = k_all
.contiguous()
.reshape(&[batch * num_kv_heads, head_dim]);
let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps);
let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps);
@@ -688,8 +763,16 @@ impl Qwen3 {
let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
let attn_out = xserv_kernels::paged_decode_attention(
&q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr,
batch, num_heads, num_kv_heads, head_dim, max_blocks,
&q_4d,
k_pool_ptr,
v_pool_ptr,
bt_ptr,
cl_ptr,
batch,
num_heads,
num_kv_heads,
head_dim,
max_blocks,
);
// attn_out shape [B, H, 1, D] is contiguous-equivalent to [B, H*D].
@@ -697,7 +780,8 @@ impl Qwen3 {
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
self.all_reduce(&attn_proj); // TP: sum partial attention outputs
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let (normed, x_new) =
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
// Fused gate+up projection: one GEMV instead of two.
@@ -743,7 +827,9 @@ impl Qwen3 {
paged_cache.advance_seq_len(slot, new_tokens);
let mut x = embedding(&self.embed_tokens, token_ids);
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
.map(|p| p as u32)
.collect();
for (layer_idx, layer) in self.layers.iter().enumerate() {
let residual = x.clone();
@@ -752,7 +838,9 @@ impl Qwen3 {
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
let v = qkv
.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
.contiguous();
let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
@@ -773,11 +861,13 @@ impl Qwen3 {
let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
let attn_out = flash_attention(&q, &k_full, &v_full, true);
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_merged =
xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
self.all_reduce(&attn_proj);
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let (normed, x_new) =
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
@@ -805,7 +895,9 @@ impl Qwen3 {
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
let mut x = embedding(&self.embed_tokens, token_ids);
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
.map(|p| p as u32)
.collect();
for (layer_idx, layer) in self.layers.iter().enumerate() {
let residual = x.clone();
@@ -814,7 +906,9 @@ impl Qwen3 {
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
let v = qkv
.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
.contiguous();
let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
@@ -834,10 +928,12 @@ impl Qwen3 {
let (k_full, v_full) = cache.get_kv_len(layer_idx, pos_offset + new_tokens);
let attn_out = flash_attention(&q, &k_full, &v_full, true);
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_merged =
xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let (normed, x_new) =
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
@@ -856,28 +952,33 @@ impl Qwen3 {
/// Extract weight pointers for CUDA Graph capture.
pub fn layer_weight_ptrs(&self) -> Vec<crate::decode_graph::LayerWeightPtrs> {
self.layers.iter().map(|l| crate::decode_graph::LayerWeightPtrs {
input_norm: l.input_norm.data_ptr() as *const std::ffi::c_void,
q_proj_wt: l.q_proj_wt().data_ptr() as *const std::ffi::c_void,
k_proj_wt: l.k_proj_wt().data_ptr() as *const std::ffi::c_void,
v_proj_wt: l.v_proj_wt().data_ptr() as *const std::ffi::c_void,
o_proj_wt: l.o_proj_wt.data_ptr() as *const std::ffi::c_void,
q_norm: l.q_norm.data_ptr() as *const std::ffi::c_void,
k_norm: l.k_norm.data_ptr() as *const std::ffi::c_void,
post_norm: l.post_norm.data_ptr() as *const std::ffi::c_void,
gate_proj_wt: l.gate_proj_wt().data_ptr() as *const std::ffi::c_void,
up_proj_wt: l.up_proj_wt().data_ptr() as *const std::ffi::c_void,
down_proj_wt: l.down_proj_wt.data_ptr() as *const std::ffi::c_void,
}).collect()
self.layers
.iter()
.map(|l| crate::decode_graph::LayerWeightPtrs {
input_norm: l.input_norm.data_ptr() as *const std::ffi::c_void,
q_proj_wt: l.q_proj_wt().data_ptr() as *const std::ffi::c_void,
k_proj_wt: l.k_proj_wt().data_ptr() as *const std::ffi::c_void,
v_proj_wt: l.v_proj_wt().data_ptr() as *const std::ffi::c_void,
o_proj_wt: l.o_proj_wt.data_ptr() as *const std::ffi::c_void,
q_norm: l.q_norm.data_ptr() as *const std::ffi::c_void,
k_norm: l.k_norm.data_ptr() as *const std::ffi::c_void,
post_norm: l.post_norm.data_ptr() as *const std::ffi::c_void,
gate_proj_wt: l.gate_proj_wt().data_ptr() as *const std::ffi::c_void,
up_proj_wt: l.up_proj_wt().data_ptr() as *const std::ffi::c_void,
down_proj_wt: l.down_proj_wt.data_ptr() as *const std::ffi::c_void,
})
.collect()
}
/// Get pointers needed for CUDA Graph capture.
pub fn graph_capture_ptrs(&self) -> (
*const std::ffi::c_void, // norm weight
*const std::ffi::c_void, // lm_head_t
*const std::ffi::c_void, // embed_tokens
*const std::ffi::c_void, // rope cos
*const std::ffi::c_void, // rope sin
pub fn graph_capture_ptrs(
&self,
) -> (
*const std::ffi::c_void, // norm weight
*const std::ffi::c_void, // lm_head_t
*const std::ffi::c_void, // embed_tokens
*const std::ffi::c_void, // rope cos
*const std::ffi::c_void, // rope sin
) {
(
self.norm.data_ptr() as *const std::ffi::c_void,
@@ -895,11 +996,16 @@ impl Qwen3 {
/// (column-parallel split: split the OUTPUT dim). `world==1` returns the whole.
/// Input must be a contiguous CPU (or device) BF16 tensor.
fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor {
if world == 1 { return t.clone(); }
if world == 1 {
return t.clone();
}
let shape = t.shape();
assert_eq!(shape.len(), 2, "shard_rows expects 2D weight");
let (rows, cols) = (shape[0], shape[1]);
assert!(rows % world == 0, "rows {rows} not divisible by world {world}");
assert!(
rows % world == 0,
"rows {rows} not divisible by world {world}"
);
let local = rows / world;
let host = t.to_device(Device::Cpu);
let data = host.as_slice::<bf16>();
@@ -911,11 +1017,16 @@ fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor {
/// Keep this rank's column-block of a 2D `[rows, cols]` BF16 tensor (row-parallel
/// split: split the INPUT dim). Strided copy. `world==1` returns the whole.
fn shard_cols(t: &Tensor, rank: usize, world: usize) -> Tensor {
if world == 1 { return t.clone(); }
if world == 1 {
return t.clone();
}
let shape = t.shape();
assert_eq!(shape.len(), 2, "shard_cols expects 2D weight");
let (rows, cols) = (shape[0], shape[1]);
assert!(cols % world == 0, "cols {cols} not divisible by world {world}");
assert!(
cols % world == 0,
"cols {cols} not divisible by world {world}"
);
let local = cols / world;
let c0 = rank * local;
let host = t.to_device(Device::Cpu);
@@ -1009,7 +1120,9 @@ fn transpose_from_rope(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: u
}
fn repeat_kv(x: &Tensor, n_rep: usize) -> Tensor {
if n_rep == 1 { return x.clone(); }
if n_rep == 1 {
return x.clone();
}
let kv_heads = x.shape()[1];
let seq_len = x.shape()[2];
let head_dim = x.shape()[3];
@@ -1065,11 +1178,16 @@ fn concat_rows(rows: &[Tensor]) -> Tensor {
let src_buf = row.storage().gpu_buffer();
let src_offset = row.offset() * elem_size;
let dst_offset = b * row_bytes;
out_buf.copy_from_device_at(src_buf, src_offset, dst_offset, row_bytes).unwrap();
out_buf
.copy_from_device_at(src_buf, src_offset, dst_offset, row_bytes)
.unwrap();
}
// Wrap in a Tensor
let device_id = match device { Device::Cuda(id) => id, _ => panic!("expected CUDA device") };
let device_id = match device {
Device::Cuda(id) => id,
_ => panic!("expected CUDA device"),
};
unsafe {
crate::kv_cache::tensor_from_gpu_buffer_pub(out_buf, &[batch, cols], dtype, device_id)
}
@@ -1082,12 +1200,15 @@ fn cat_cols(tensors: &[&Tensor]) -> Tensor {
let dtype = tensors[0].dtype();
let device = tensors[0].device();
let elem = dtype.size_bytes();
let total_cols: usize = tensors.iter().map(|t| {
assert_eq!(t.ndim(), 2);
assert_eq!(t.shape()[0], rows);
assert!(t.is_contiguous());
t.shape()[1]
}).sum();
let total_cols: usize = tensors
.iter()
.map(|t| {
assert_eq!(t.ndim(), 2);
assert_eq!(t.shape()[0], rows);
assert!(t.is_contiguous());
t.shape()[1]
})
.sum();
let out = Tensor::empty(&[rows, total_cols], dtype, device);
let dst_base = out.data_ptr() as *mut u8;
for r in 0..rows {
@@ -1126,7 +1247,9 @@ pub fn sample_greedy(logits: &Tensor) -> u32 {
let seq_len = logits.shape()[0];
let data = logits_cpu.as_slice::<bf16>();
let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
last.iter().enumerate()
last.iter()
.enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(i, _)| i as u32).unwrap()
.map(|(i, _)| i as u32)
.unwrap()
}

View File

@@ -11,7 +11,11 @@ pub struct SamplingParams {
impl Default for SamplingParams {
fn default() -> Self {
Self { temperature: 0.0, top_k: 0, top_p: 1.0 }
Self {
temperature: 0.0,
top_k: 0,
top_p: 1.0,
}
}
}
@@ -134,9 +138,14 @@ pub fn sample_greedy_penalized(logits: &Tensor, recent: &[u32], penalty: f32) ->
let seq_len = logits.shape()[0];
let logits_cpu = logits.to_device(Device::Cpu);
let mut last_row: Vec<f32> = match logits.dtype() {
DType::F32 => logits_cpu.as_slice::<f32>()[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec(),
DType::BF16 => logits_cpu.as_slice::<bf16>()[(seq_len - 1) * vocab_size..seq_len * vocab_size]
.iter().map(|v| v.to_f32()).collect(),
DType::F32 => {
logits_cpu.as_slice::<f32>()[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec()
}
DType::BF16 => logits_cpu.as_slice::<bf16>()
[(seq_len - 1) * vocab_size..seq_len * vocab_size]
.iter()
.map(|v| v.to_f32())
.collect(),
_ => panic!("unsupported dtype for sampling: {:?}", logits.dtype()),
};
if penalty > 1.0 {

View File

@@ -72,7 +72,10 @@ impl ChatTemplate {
let source = std::fs::read_to_string(&jinja_path)
.unwrap_or_else(|e| panic!("failed to read {}: {e}", jinja_path.display()));
eprintln!("[chat-template] loaded from {}", jinja_path.display());
return Self { source, model_type: model_type.to_string() };
return Self {
source,
model_type: model_type.to_string(),
};
}
// 2. Try tokenizer_config.json → chat_template field
@@ -82,7 +85,10 @@ impl ChatTemplate {
if let Ok(v) = serde_json::from_str::<serde_json::Value>(&data) {
if let Some(ct) = v.get("chat_template").and_then(|v| v.as_str()) {
eprintln!("[chat-template] loaded from tokenizer_config.json");
return Self { source: ct.to_string(), model_type: model_type.to_string() };
return Self {
source: ct.to_string(),
model_type: model_type.to_string(),
};
}
}
}
@@ -90,7 +96,10 @@ impl ChatTemplate {
// 3. No template found — use empty source, will fall back to hardcoded
eprintln!("[chat-template] no Jinja template found, using hardcoded fallback");
Self { source: String::new(), model_type: model_type.to_string() }
Self {
source: String::new(),
model_type: model_type.to_string(),
}
}
pub fn render(&self, messages: &[Message]) -> String {
@@ -206,7 +215,10 @@ fn build_prompt_gpt_oss(messages: &[Message]) -> String {
prompt.push_str("<|start|>system<|message|>");
prompt.push_str("You are ChatGPT, a large language model trained by OpenAI.\n");
prompt.push_str("Knowledge cutoff: 2024-06\n");
prompt.push_str(&format!("Current date: {}\n\n", strftime_now("%Y-%m-%d".to_string())));
prompt.push_str(&format!(
"Current date: {}\n\n",
strftime_now("%Y-%m-%d".to_string())
));
prompt.push_str("Reasoning: low\n\n");
prompt.push_str("# Valid channels: analysis, commentary, final. Channel must be included for every message.");
prompt.push_str("<|end|>");
@@ -334,13 +346,11 @@ async fn chat_non_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
"completion_tokens": completion_token_count,
"total_tokens": prompt_token_count + completion_token_count
}
})).into_response()
}))
.into_response()
}
fn chat_stream(
state: Arc<AppState>,
req: ChatRequest,
) -> Response {
fn chat_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
let id = format!("chatcmpl-{}", Uuid::new_v4());
let model_name = state.model_name.clone();
let created = unix_timestamp();
@@ -356,7 +366,8 @@ fn chat_stream(
if prompt_tokens.len() >= max_seq_len {
return bad_request(format!(
"prompt is {} tokens, exceeds max_seq_len {}",
prompt_tokens.len(), max_seq_len
prompt_tokens.len(),
max_seq_len
));
}
let max_tokens = req.max_tokens.min(max_seq_len - prompt_tokens.len());
@@ -413,7 +424,9 @@ fn chat_stream(
}
});
Sse::new(ReceiverStream::new(sse_rx)).keep_alive(KeepAlive::default()).into_response()
Sse::new(ReceiverStream::new(sse_rx))
.keep_alive(KeepAlive::default())
.into_response()
}
fn validate_request(req: &ChatRequest, model_name: &str) -> Option<Response> {
@@ -436,8 +449,13 @@ fn validate_request(req: &ChatRequest, model_name: &str) -> Option<Response> {
/// prior handler panicked) and returns a clean 503 instead of panicking when the
/// engine thread is gone, so one dead engine doesn't cascade into every request.
fn submit_to_engine(state: &AppState, req: GenerateRequest) -> Result<(), Response> {
let sender = state.engine_sender.lock().unwrap_or_else(|e| e.into_inner());
sender.send(req).map_err(|_| service_unavailable("inference engine is not available"))
let sender = state
.engine_sender
.lock()
.unwrap_or_else(|e| e.into_inner());
sender
.send(req)
.map_err(|_| service_unavailable("inference engine is not available"))
}
fn service_unavailable(message: impl Into<String>) -> Response {

View File

@@ -1,10 +1,10 @@
use std::collections::VecDeque;
use std::path::Path;
use std::sync::mpsc;
use std::sync::Once;
use std::sync::mpsc;
use std::time::Instant;
use xserv_model::{ModelConfig, PagedKVCache, Qwen3, SamplingParams, sample, BLOCK_SIZE};
use xserv_model::loader;
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, SamplingParams, sample};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
@@ -109,12 +109,23 @@ impl Engine {
(total_blocks * bytes_per_block) as f64 / 1e9,
info.free_memory as f64 / 1e9,
);
Self { model, config, tokenizer, max_batch_size, max_seq_len, paged_cache }
Self {
model,
config,
tokenizer,
max_batch_size,
max_seq_len,
paged_cache,
}
}
pub fn tokenizer(&self) -> &Tokenizer { &self.tokenizer }
pub fn tokenizer(&self) -> &Tokenizer {
&self.tokenizer
}
pub fn max_seq_len(&self) -> usize { self.max_seq_len }
pub fn max_seq_len(&self) -> usize {
self.max_seq_len
}
/// Main scheduler loop. Receives requests from channel, manages concurrent sequences.
///
@@ -134,7 +145,8 @@ impl Engine {
loop {
// Step 1: Remove finished sequences and return their slots.
let finished_slots: Vec<usize> = running.iter()
let finished_slots: Vec<usize> = running
.iter()
.filter(|s| is_finished(s))
.filter_map(|s| s.seq_slot)
.collect();
@@ -147,10 +159,16 @@ impl Engine {
// room (oldest first). They resume decoding from where they paused.
while running.len() < self.max_batch_size && !swapped.is_empty() {
let slot = swapped[0].seq_slot.expect("swapped slot");
if !self.paged_cache.can_swap_in(slot) { break; }
if !self.paged_cache.can_swap_in(slot) {
break;
}
self.paged_cache.swap_in(slot).expect("swap_in");
let seq = swapped.remove(0);
eprintln!("[scheduler] swapped in seq {} ({} blocks)", seq.id, self.paged_cache.block_count(slot));
eprintln!(
"[scheduler] swapped in seq {} ({} blocks)",
seq.id,
self.paged_cache.block_count(slot)
);
running.push(seq);
}
@@ -161,14 +179,22 @@ impl Engine {
let mut avail = self.paged_cache.free_blocks();
let decode_reserve = running.len();
while running.len() < self.max_batch_size {
let Some(front) = waiting.front() else { break; };
let Some(front) = waiting.front() else {
break;
};
let prompt_blocks = front.prompt_tokens.len().div_ceil(BLOCK_SIZE).max(1);
if avail < prompt_blocks + decode_reserve { break; }
if avail < prompt_blocks + decode_reserve {
break;
}
let free_slot = (0..self.paged_cache.max_seqs())
.find(|&s| self.paged_cache.is_slot_free(s));
let Some(slot) = free_slot else { break; };
let Some(slot) = free_slot else {
break;
};
let mut seq = waiting.pop_front().unwrap();
self.paged_cache.register_sequence(slot).expect("register paged slot");
self.paged_cache
.register_sequence(slot)
.expect("register paged slot");
seq.seq_slot = Some(slot);
running.push(seq);
avail -= prompt_blocks; // projected free after this seq prefills
@@ -199,7 +225,9 @@ impl Engine {
if !seq.prefilled {
let slot = seq.seq_slot.expect("slot");
let logits = self.model.forward_prefill_paged(
&seq.prompt_tokens, slot, &mut self.paged_cache,
&seq.prompt_tokens,
slot,
&mut self.paged_cache,
);
let next = sample(&logits, &seq.sampling);
seq.generated_tokens.push(next);
@@ -219,13 +247,18 @@ impl Engine {
&& !newly_prefilled.contains(&running[p].id)
&& running[p].seq_slot.is_some()
});
let Some(pos) = victim else { break; };
let Some(pos) = victim else {
break;
};
let seq = running.remove(pos);
let slot = seq.seq_slot.unwrap();
if self.paged_cache.can_swap_out(slot) {
let nblocks = self.paged_cache.block_count(slot);
self.paged_cache.swap_out(slot).expect("swap_out");
eprintln!("[scheduler] preempt: swapped out seq {} ({nblocks} blocks) to host", seq.id);
eprintln!(
"[scheduler] preempt: swapped out seq {} ({nblocks} blocks) to host",
seq.id
);
swapped.push(seq);
needed = decode_block_need(&self.paged_cache, &running, &newly_prefilled);
} else {
@@ -235,7 +268,9 @@ impl Engine {
}
// Step 5c: Batched paged decode for the surviving prefilled sequences.
let decode_indices: Vec<usize> = running.iter().enumerate()
let decode_indices: Vec<usize> = running
.iter()
.enumerate()
.filter(|(_, s)| s.prefilled && !newly_prefilled.contains(&s.id))
.map(|(i, _)| i)
.collect();
@@ -246,25 +281,32 @@ impl Engine {
eprintln!("[scheduler] paged decode active");
});
let tokens: Vec<u32> = decode_indices.iter()
let tokens: Vec<u32> = decode_indices
.iter()
.map(|&i| *running[i].generated_tokens.last().unwrap())
.collect();
let positions: Vec<usize> = decode_indices.iter()
let positions: Vec<usize> = decode_indices
.iter()
.map(|&i| self.paged_cache.seq_len(running[i].seq_slot.unwrap()))
.collect();
let slots: Vec<usize> = decode_indices.iter()
let slots: Vec<usize> = decode_indices
.iter()
.map(|&i| running[i].seq_slot.unwrap())
.collect();
let logits = self.model.forward_decode_paged(
&tokens, &positions, &slots, &mut self.paged_cache,
&tokens,
&positions,
&slots,
&mut self.paged_cache,
);
// Fast path: every active sequence is greedy → run argmax on
// the GPU and only D2H the chosen token ids (a few bytes per
// sequence) instead of the full [B, vocab_size] BF16 logits
// (~1.2 MB for B=4, Qwen3 vocab=152K).
let all_greedy = decode_indices.iter()
let all_greedy = decode_indices
.iter()
.all(|&i| running[i].sampling.temperature == 0.0);
if all_greedy {
let next_ids = xserv_kernels::argmax_bf16_to_host(&logits);
@@ -285,11 +327,15 @@ impl Engine {
let row_start = j * vocab_size;
let row_logits = &data[row_start..row_start + vocab_size];
let next = if running[i].sampling.temperature == 0.0 {
row_logits.iter().enumerate()
row_logits
.iter()
.enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(idx, _)| idx as u32).unwrap()
.map(|(idx, _)| idx as u32)
.unwrap()
} else {
let row_tensor = xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]);
let row_tensor =
xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]);
sample(&row_tensor, &running[i].sampling)
};
running[i].generated_tokens.push(next);
@@ -334,7 +380,8 @@ impl Engine {
/// Total additional GPU blocks the next decode step needs across all
/// currently-decoding (prefilled, not just-prefilled) sequences.
fn decode_block_need(paged: &PagedKVCache, running: &[Sequence], newly_prefilled: &[u64]) -> usize {
running.iter()
running
.iter()
.filter(|s| s.prefilled && !newly_prefilled.contains(&s.id))
.filter_map(|s| s.seq_slot)
.map(|slot| paged.additional_blocks_needed(slot, 1))
@@ -372,8 +419,12 @@ fn send_token_if_nonempty(seq: &Sequence, text: String) {
}
fn is_finished(seq: &Sequence) -> bool {
if seq.generated_tokens.is_empty() { return false; }
if seq.generated_tokens.is_empty() {
return false;
}
let last = *seq.generated_tokens.last().unwrap();
if seq.generated_tokens.len() >= seq.max_tokens { return true; }
if seq.generated_tokens.len() >= seq.max_tokens {
return true;
}
seq.sender.is_closed() || seq.eos_token_id == Some(last)
}

View File

@@ -3,10 +3,13 @@ mod engine;
mod pp_engine;
mod tp_engine;
use axum::{routing::{get, post}, Extension, Router};
use std::path::PathBuf;
use std::sync::{mpsc, Arc, Mutex};
use axum::{
Extension, Router,
routing::{get, post},
};
use engine::GenerateRequest;
use std::path::PathBuf;
use std::sync::{Arc, Mutex, mpsc};
use xserv_model::ModelConfig;
pub struct AppState {
@@ -21,40 +24,48 @@ pub struct AppState {
async fn main() {
let args: Vec<String> = std::env::args().collect();
if args.len() < 2 {
eprintln!("Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N] [--tp N] [--pp N]");
eprintln!(
"Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N] [--tp N] [--pp N]"
);
std::process::exit(1);
}
let model_dir = PathBuf::from(&args[1]);
let port: u16 = args.iter()
let port: u16 = args
.iter()
.position(|a| a == "--port")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(8080);
let max_batch: usize = args.iter()
let max_batch: usize = args
.iter()
.position(|a| a == "--max-batch")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(4)
.max(1);
let requested_max_seq_len: usize = args.iter()
let requested_max_seq_len: usize = args
.iter()
.position(|a| a == "--max-seq-len")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(2048)
.max(1);
let swap_space_gb: usize = args.iter()
let swap_space_gb: usize = args
.iter()
.position(|a| a == "--swap-space-gb")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(8);
let tp: usize = args.iter()
let tp: usize = args
.iter()
.position(|a| a == "--tp")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(1)
.max(1);
let pp: usize = args.iter()
let pp: usize = args
.iter()
.position(|a| a == "--pp")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
@@ -69,7 +80,9 @@ async fn main() {
// tp=1 (single-rank world) so quantized models can serve on one GPU.
let is_gpt_oss = model_config.model_type.as_deref() == Some("gpt_oss");
if pp > 1 && is_gpt_oss {
eprintln!("gpt-oss is not supported by the pipeline-parallel engine (Qwen3 only); use --tp instead");
eprintln!(
"gpt-oss is not supported by the pipeline-parallel engine (Qwen3 only); use --tp instead"
);
std::process::exit(1);
}
let model_max_seq_len = model_config.max_seq_len();
@@ -84,7 +97,8 @@ async fn main() {
);
}
let model_name = model_dir.file_name()
let model_name = model_dir
.file_name()
.map(|n| n.to_string_lossy().to_string())
.unwrap_or_else(|| "unknown".to_string());
@@ -99,7 +113,12 @@ async fn main() {
// Pipeline-parallel path: stage-0 coordinator + worker stage threads.
pp_engine::run_pp(&model_dir_clone, pp, max_seq_len, rx);
} else if tp <= 1 && !is_gpt_oss {
let mut engine = engine::Engine::load_with_swap(&model_dir_clone, max_batch, max_seq_len, swap_space_gb);
let mut engine = engine::Engine::load_with_swap(
&model_dir_clone,
max_batch,
max_seq_len,
swap_space_gb,
);
engine.run(rx);
} else {
// Tensor-parallel path: rank-0 coordinator + worker rank threads.

View File

@@ -15,15 +15,15 @@
use std::ffi::c_void;
use std::path::{Path, PathBuf};
use std::sync::mpsc;
use std::sync::Arc;
use std::sync::mpsc;
use std::thread;
use half::bf16;
use xserv_distributed::{PpContext, UniqueId};
use xserv_model::loader;
use xserv_model::sampling::SamplingParams;
use xserv_model::{sample, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, sample};
use xserv_tensor::{DType, Device, Tensor};
use xserv_tokenizer::Tokenizer;
@@ -38,9 +38,16 @@ enum PpCommand {
Free(usize),
/// Receive `[n_tokens, hidden]` from the previous stage, run this stage's
/// layers; if last stage, sample with `sampling` and return the token.
Prefill { n_tokens: usize, slot: usize, sampling: SamplingParams },
Prefill {
n_tokens: usize,
slot: usize,
sampling: SamplingParams,
},
/// Receive `[1, hidden]`, run this stage's layers; last stage samples.
Decode { slot: usize, sampling: SamplingParams },
Decode {
slot: usize,
sampling: SamplingParams,
},
Shutdown,
}
@@ -76,9 +83,21 @@ fn build_stage(
let max_blocks_per_seq = max_seq_len.div_ceil(BLOCK_SIZE);
let total_blocks = max_blocks_per_seq + 8; // v1 serial: one active sequence
let cache = PagedKVCache::new(
&stage_config, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device,
&stage_config,
total_blocks,
0,
4,
max_blocks_per_seq,
DType::BF16,
device,
);
StageCtx { model, cache, pp, hidden: config.hidden(), device }
StageCtx {
model,
cache,
pp,
hidden: config.hidden(),
device,
}
}
/// Allocate a zeroed `[n, hidden]` device tensor and receive into it from `peer`.
@@ -110,7 +129,15 @@ fn worker_loop(
ack_tx: mpsc::Sender<()>,
token_tx: mpsc::Sender<u32>,
) {
let mut sc = build_stage(&model_dir, &config, stage, world, stage as u32, max_seq_len, id);
let mut sc = build_stage(
&model_dir,
&config,
stage,
world,
stage as u32,
max_seq_len,
id,
);
let is_last = stage == world - 1;
let prev = stage - 1;
let next = stage + 1;
@@ -125,7 +152,11 @@ fn worker_loop(
sc.cache.free_sequence(slot);
let _ = ack_tx.send(());
}
PpCommand::Prefill { n_tokens, slot, sampling } => {
PpCommand::Prefill {
n_tokens,
slot,
sampling,
} => {
let x = recv_hidden(&sc, n_tokens, prev);
let x = sc.model.forward_layers_prefill(x, slot, &mut sc.cache);
if is_last {
@@ -155,7 +186,12 @@ fn worker_loop(
/// Run the PP coordinator (stage 0) on the calling thread. Spawns worker stages
/// 1..world and consumes generation requests from `rx`.
pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Receiver<GenerateRequest>) {
pub fn run_pp(
model_dir: &Path,
world: usize,
max_seq_len: usize,
rx: mpsc::Receiver<GenerateRequest>,
) {
assert!(world >= 2, "run_pp requires world >= 2");
let config = ModelConfig::from_file(&model_dir.join("config.json"));
assert!(
@@ -179,7 +215,17 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
let model_dir = model_dir.to_path_buf();
let config = config.clone();
thread::spawn(move || {
worker_loop(stage, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx, token_tx);
worker_loop(
stage,
world,
id,
model_dir,
config,
max_seq_len,
ctx_rx,
ack_tx,
token_tx,
);
});
}
@@ -207,11 +253,14 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
wait_acks(&ack_rx);
// Prefill: embed prompt, run stage-0 layers, push hidden into the pipe.
broadcast(&cmd_txs, PpCommand::Prefill {
n_tokens: req.prompt_tokens.len(),
slot,
sampling: req.sampling.clone(),
});
broadcast(
&cmd_txs,
PpCommand::Prefill {
n_tokens: req.prompt_tokens.len(),
slot,
sampling: req.sampling.clone(),
},
);
let x = sc.model.embed(&req.prompt_tokens);
let x = sc.model.forward_layers_prefill(x, slot, &mut sc.cache);
send_hidden(&sc, &x, next_peer);
@@ -228,7 +277,13 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
if generated >= req.max_tokens {
break "length";
}
broadcast(&cmd_txs, PpCommand::Decode { slot, sampling: req.sampling.clone() });
broadcast(
&cmd_txs,
PpCommand::Decode {
slot,
sampling: req.sampling.clone(),
},
);
let x = sc.model.embed(&[next]);
let x = sc.model.forward_layers_decode(x, &[slot], &mut sc.cache);
send_hidden(&sc, &x, next_peer);
@@ -239,9 +294,14 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
let tail = tokenizer.flush_decode_stream(&mut decode_buf);
if !tail.is_empty() {
let _ = req.sender.blocking_send(GenerateEvent::Token { id: next, text: tail });
let _ = req.sender.blocking_send(GenerateEvent::Token {
id: next,
text: tail,
});
}
let _ = req.sender.blocking_send(GenerateEvent::Done { finish_reason: finish.to_string() });
let _ = req.sender.blocking_send(GenerateEvent::Done {
finish_reason: finish.to_string(),
});
broadcast(&cmd_txs, PpCommand::Free(slot));
sc.cache.free_sequence(slot);
@@ -258,6 +318,8 @@ fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, buf: &
}
let text = tokenizer.decode_token_stream(token_id, buf);
if !text.is_empty() {
let _ = req.sender.blocking_send(GenerateEvent::Token { id: token_id, text });
let _ = req
.sender
.blocking_send(GenerateEvent::Token { id: token_id, text });
}
}

View File

@@ -13,13 +13,16 @@
//! work; the single-GPU `Engine` still handles TP=1.
use std::path::{Path, PathBuf};
use std::sync::mpsc;
use std::sync::Arc;
use std::sync::mpsc;
use std::thread;
use xserv_distributed::{TpContext, UniqueId};
use xserv_model::loader;
use xserv_model::{sample, sample_greedy_penalized, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
use xserv_model::{
BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, sample,
sample_greedy_penalized,
};
use xserv_tensor::{DType, Device, Tensor};
use xserv_tokenizer::Tokenizer;
@@ -29,8 +32,15 @@ use crate::engine::{GenerateEvent, GenerateRequest};
enum TpCommand {
Register(usize),
Free(usize),
Prefill { tokens: Vec<u32>, slot: usize },
Decode { tokens: Vec<u32>, positions: Vec<usize>, slots: Vec<usize> },
Prefill {
tokens: Vec<u32>,
slot: usize,
},
Decode {
tokens: Vec<u32>,
positions: Vec<usize>,
slots: Vec<usize>,
},
Shutdown,
}
@@ -40,14 +50,25 @@ enum TpModel {
}
impl TpModel {
fn forward_prefill_paged(&self, tokens: &[u32], slot: usize, cache: &mut PagedKVCache) -> Tensor {
fn forward_prefill_paged(
&self,
tokens: &[u32],
slot: usize,
cache: &mut PagedKVCache,
) -> Tensor {
match self {
TpModel::Qwen3(m) => m.forward_prefill_paged(tokens, slot, cache),
TpModel::GptOss(m) => m.forward_prefill_paged(tokens, slot, cache),
}
}
fn forward_decode_paged(&self, tokens: &[u32], positions: &[usize], slots: &[usize], cache: &mut PagedKVCache) -> Tensor {
fn forward_decode_paged(
&self,
tokens: &[u32],
positions: &[usize],
slots: &[usize],
cache: &mut PagedKVCache,
) -> Tensor {
match self {
TpModel::Qwen3(m) => m.forward_decode_paged(tokens, positions, slots, cache),
TpModel::GptOss(m) => m.forward_decode_paged(tokens, positions, slots, cache),
@@ -65,8 +86,12 @@ struct RankCtx {
/// (lazy capture, replay thereafter); everything else runs eager.
fn rank_decode(rc: &mut RankCtx, tokens: &[u32], positions: &[usize], slots: &[usize]) -> Tensor {
match &rc.model {
TpModel::GptOss(m) => rc.decoder.decode(m, tokens, positions, slots, &mut rc.cache),
TpModel::Qwen3(_) => rc.model.forward_decode_paged(tokens, positions, slots, &mut rc.cache),
TpModel::GptOss(m) => rc
.decoder
.decode(m, tokens, positions, slots, &mut rc.cache),
TpModel::Qwen3(_) => rc
.model
.forward_decode_paged(tokens, positions, slots, &mut rc.cache),
}
}
@@ -81,17 +106,42 @@ fn build_rank(
) -> RankCtx {
let weights = loader::load_model_dir(model_dir, Device::Cpu);
let model = if config.is_moe() {
TpModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, rank, world, device, tp))
TpModel::GptOss(GptOss::from_weights_tp(
config.clone(),
weights,
rank,
world,
device,
tp,
))
} else {
TpModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, rank, world, device, tp))
TpModel::Qwen3(Qwen3::from_weights_tp(
config.clone(),
weights,
rank,
world,
device,
tp,
))
};
let local_kv = config.num_kv_heads() / world;
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
let total_blocks = max_blocks_per_seq + 8;
let cache = PagedKVCache::new_tp(
config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device,
config,
local_kv,
total_blocks,
0,
4,
max_blocks_per_seq,
DType::BF16,
device,
);
RankCtx { model, cache, decoder: GraphedGptOssDecoder::new() }
RankCtx {
model,
cache,
decoder: GraphedGptOssDecoder::new(),
}
}
fn worker_loop(
@@ -105,7 +155,15 @@ fn worker_loop(
ack_tx: mpsc::Sender<()>,
) {
let tp = Arc::new(TpContext::init(rank, world, id, rank as u32));
let mut rc = build_rank(&model_dir, &config, rank, world, rank as u32, max_seq_len, Some(tp));
let mut rc = build_rank(
&model_dir,
&config,
rank,
world,
rank as u32,
max_seq_len,
Some(tp),
);
while let Ok(cmd) = cmd_rx.recv() {
match cmd {
TpCommand::Register(slot) => {
@@ -115,7 +173,11 @@ fn worker_loop(
TpCommand::Prefill { tokens, slot } => {
let _ = rc.model.forward_prefill_paged(&tokens, slot, &mut rc.cache);
}
TpCommand::Decode { tokens, positions, slots } => {
TpCommand::Decode {
tokens,
positions,
slots,
} => {
let _ = rank_decode(&mut rc, &tokens, &positions, &slots);
}
TpCommand::Shutdown => {
@@ -129,7 +191,12 @@ fn worker_loop(
/// Run the TP coordinator (rank 0) on the calling thread. Spawns worker ranks
/// internally and consumes generation requests from `rx`.
pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Receiver<GenerateRequest>) {
pub fn run_tp(
model_dir: &Path,
world: usize,
max_seq_len: usize,
rx: mpsc::Receiver<GenerateRequest>,
) {
// world=1 is a valid single-rank configuration (gpt-oss has no
// single-GPU engine path; NCCL init and all_reduce no-op at world=1).
assert!(world >= 1, "run_tp requires world >= 1");
@@ -152,7 +219,16 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
let model_dir = model_dir.to_path_buf();
let config = config.clone();
thread::spawn(move || {
worker_loop(rank, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx);
worker_loop(
rank,
world,
id,
model_dir,
config,
max_seq_len,
ctx_rx,
ack_tx,
);
});
}
@@ -165,10 +241,14 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
// models loop under pure greedy when numerics diverge from the reference).
// Off by default; XSERV_REP_PENALTY>1 enables it over the last
// XSERV_REP_WINDOW generated tokens. Applied only on the greedy path.
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY").ok()
.and_then(|s| s.parse().ok()).unwrap_or(1.0);
let rep_window: usize = std::env::var("XSERV_REP_WINDOW").ok()
.and_then(|s| s.parse().ok()).unwrap_or(128);
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(1.0);
let rep_window: usize = std::env::var("XSERV_REP_WINDOW")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(128);
let pick = |logits: &Tensor, sp: &xserv_model::SamplingParams, history: &[u32]| -> u32 {
if rep_penalty > 1.0 && sp.temperature == 0.0 {
let start = history.len().saturating_sub(rep_window);
@@ -197,8 +277,16 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
wait_acks(&ack_rx);
// Prefill.
broadcast(&cmd_txs, TpCommand::Prefill { tokens: req.prompt_tokens.clone(), slot });
let logits = rc.model.forward_prefill_paged(&req.prompt_tokens, slot, &mut rc.cache);
broadcast(
&cmd_txs,
TpCommand::Prefill {
tokens: req.prompt_tokens.clone(),
slot,
},
);
let logits = rc
.model
.forward_prefill_paged(&req.prompt_tokens, slot, &mut rc.cache);
wait_acks(&ack_rx);
let mut gen_ids: Vec<u32> = Vec::new();
let mut next = pick(&logits, &req.sampling, &gen_ids);
@@ -216,7 +304,14 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
break "length";
}
let pos = rc.cache.seq_len(slot);
broadcast(&cmd_txs, TpCommand::Decode { tokens: vec![next], positions: vec![pos], slots: vec![slot] });
broadcast(
&cmd_txs,
TpCommand::Decode {
tokens: vec![next],
positions: vec![pos],
slots: vec![slot],
},
);
let logits = rank_decode(&mut rc, &[next], &[pos], &[slot]);
wait_acks(&ack_rx);
next = pick(&logits, &req.sampling, &gen_ids);
@@ -227,9 +322,14 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
let tail = tokenizer.flush_decode_stream(&mut decode_buf);
if !tail.is_empty() {
let _ = req.sender.blocking_send(GenerateEvent::Token { id: next, text: tail });
let _ = req.sender.blocking_send(GenerateEvent::Token {
id: next,
text: tail,
});
}
let _ = req.sender.blocking_send(GenerateEvent::Done { finish_reason: finish.to_string() });
let _ = req.sender.blocking_send(GenerateEvent::Done {
finish_reason: finish.to_string(),
});
broadcast(&cmd_txs, TpCommand::Free(slot));
rc.cache.free_sequence(slot);
@@ -246,6 +346,8 @@ fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, buf: &
}
let text = tokenizer.decode_token_stream(token_id, buf);
if !text.is_empty() {
let _ = req.sender.blocking_send(GenerateEvent::Token { id: token_id, text });
let _ = req
.sender
.blocking_send(GenerateEvent::Token { id: token_id, text });
}
}

View File

@@ -43,18 +43,30 @@ pub trait TensorDType: Copy + Send + Sync + 'static {
impl TensorDType for f32 {
const DTYPE: DType = DType::F32;
fn to_f64(self) -> f64 { self as f64 }
fn from_f64(v: f64) -> Self { v as f32 }
fn to_f64(self) -> f64 {
self as f64
}
fn from_f64(v: f64) -> Self {
v as f32
}
}
impl TensorDType for f16 {
const DTYPE: DType = DType::F16;
fn to_f64(self) -> f64 { self.to_f32() as f64 }
fn from_f64(v: f64) -> Self { f16::from_f32(v as f32) }
fn to_f64(self) -> f64 {
self.to_f32() as f64
}
fn from_f64(v: f64) -> Self {
f16::from_f32(v as f32)
}
}
impl TensorDType for bf16 {
const DTYPE: DType = DType::BF16;
fn to_f64(self) -> f64 { self.to_f32() as f64 }
fn from_f64(v: f64) -> Self { bf16::from_f32(v as f32) }
fn to_f64(self) -> f64 {
self.to_f32() as f64
}
fn from_f64(v: f64) -> Self {
bf16::from_f32(v as f32)
}
}

View File

@@ -6,4 +6,4 @@ pub mod tensor;
pub use dtype::{DType, TensorDType};
pub use shape::Dims;
pub use storage::{Device, Storage};
pub use tensor::{register_gpu_contiguous, Tensor};
pub use tensor::{Tensor, register_gpu_contiguous};

View File

@@ -46,8 +46,16 @@ pub fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Dims> {
let ndim = a.len().max(b.len());
let mut result = SmallVec::with_capacity(ndim);
for i in 0..ndim {
let da = if i < ndim - a.len() { 1 } else { a[i - (ndim - a.len())] };
let db = if i < ndim - b.len() { 1 } else { b[i - (ndim - b.len())] };
let da = if i < ndim - a.len() {
1
} else {
a[i - (ndim - a.len())]
};
let db = if i < ndim - b.len() {
1
} else {
b[i - (ndim - b.len())]
};
if da == db {
result.push(da);
} else if da == 1 {
@@ -100,8 +108,14 @@ mod tests {
#[test]
fn test_broadcast_shape() {
assert_eq!(broadcast_shape(&[3, 1], &[1, 4]).unwrap().as_slice(), &[3, 4]);
assert_eq!(broadcast_shape(&[2, 3, 4], &[4]).unwrap().as_slice(), &[2, 3, 4]);
assert_eq!(
broadcast_shape(&[3, 1], &[1, 4]).unwrap().as_slice(),
&[3, 4]
);
assert_eq!(
broadcast_shape(&[2, 3, 4], &[4]).unwrap().as_slice(),
&[2, 3, 4]
);
assert_eq!(broadcast_shape(&[1], &[5, 3]).unwrap().as_slice(), &[5, 3]);
assert!(broadcast_shape(&[3], &[4]).is_none());
}
@@ -109,6 +123,9 @@ mod tests {
#[test]
fn test_broadcast_strides() {
// [3,1] with strides [1,1] broadcast to [3,4]
assert_eq!(broadcast_strides(&[3, 1], &[1, 1], &[3, 4]).as_slice(), &[1, 0]);
assert_eq!(
broadcast_strides(&[3, 1], &[1, 1], &[3, 4]).as_slice(),
&[1, 0]
);
}
}

View File

@@ -33,8 +33,20 @@ impl Tensor {
// --- Creation ---
/// Create a tensor from raw components (for advanced use like GPU KV cache).
pub fn from_storage(storage: Storage, shape: Dims, strides: Dims, offset: usize, dtype: DType) -> Self {
Self { storage, shape, strides, offset, dtype }
pub fn from_storage(
storage: Storage,
shape: Dims,
strides: Dims,
offset: usize,
dtype: DType,
) -> Self {
Self {
storage,
shape,
strides,
offset,
dtype,
}
}
pub fn from_slice<T: TensorDType>(data: &[T], shape: &[usize]) -> Self {
@@ -60,7 +72,10 @@ impl Tensor {
data.len(),
numel * dtype.size_bytes(),
"raw bytes length {} != expected {} (numel={} * elem_size={})",
data.len(), numel * dtype.size_bytes(), numel, dtype.size_bytes()
data.len(),
numel * dtype.size_bytes(),
numel,
dtype.size_bytes()
);
Self {
storage: Storage::cpu(data.to_vec()),
@@ -112,14 +127,28 @@ impl Tensor {
// --- Properties ---
pub fn shape(&self) -> &[usize] { &self.shape }
pub fn strides(&self) -> &[usize] { &self.strides }
pub fn dtype(&self) -> DType { self.dtype }
pub fn ndim(&self) -> usize { self.shape.len() }
pub fn numel(&self) -> usize { shape::num_elements(&self.shape) }
pub fn offset(&self) -> usize { self.offset }
pub fn shape(&self) -> &[usize] {
&self.shape
}
pub fn strides(&self) -> &[usize] {
&self.strides
}
pub fn dtype(&self) -> DType {
self.dtype
}
pub fn ndim(&self) -> usize {
self.shape.len()
}
pub fn numel(&self) -> usize {
shape::num_elements(&self.shape)
}
pub fn offset(&self) -> usize {
self.offset
}
pub fn device(&self) -> Device { self.storage.device() }
pub fn device(&self) -> Device {
self.storage.device()
}
pub fn is_contiguous(&self) -> bool {
shape::is_contiguous(&self.shape, &self.strides)
@@ -193,7 +222,11 @@ impl Tensor {
shape::contiguous_strides(&new_shape)
} else {
let mut s = self.strides.clone();
let stride_val = if dim < self.strides.len() { self.strides[dim] } else { 1 };
let stride_val = if dim < self.strides.len() {
self.strides[dim]
} else {
1
};
s.insert(dim, stride_val);
s
};
@@ -230,7 +263,12 @@ impl Tensor {
let ndim = self.ndim();
let mut idx = vec![0usize; ndim];
for flat in 0..numel {
let src_offset = self.offset + idx.iter().zip(self.strides.iter()).map(|(i, s)| i * s).sum::<usize>();
let src_offset = self.offset
+ idx
.iter()
.zip(self.strides.iter())
.map(|(i, s)| i * s)
.sum::<usize>();
let src_byte_offset = src_offset * elem_size;
let dst_byte_offset = flat * elem_size;
dst[dst_byte_offset..dst_byte_offset + elem_size]
@@ -261,7 +299,10 @@ impl Tensor {
}
// Transfer the raw storage (preserving strides/offset).
// Non-contiguous layout is preserved — the user can call contiguous() after.
let new_storage = self.storage.to_device(device).expect("device transfer failed");
let new_storage = self
.storage
.to_device(device)
.expect("device transfer failed");
Self {
storage: new_storage,
shape: self.shape.clone(),
@@ -310,14 +351,20 @@ impl Tensor {
}
}
pub fn storage(&self) -> &Storage { &self.storage }
pub fn storage(&self) -> &Storage {
&self.storage
}
}
impl std::fmt::Debug for Tensor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f, "Tensor(shape={:?}, dtype={}, device={}, contiguous={})",
self.shape.as_slice(), self.dtype, self.device(), self.is_contiguous()
f,
"Tensor(shape={:?}, dtype={}, device={}, contiguous={})",
self.shape.as_slice(),
self.dtype,
self.device(),
self.is_contiguous()
)
}
}

View File

@@ -32,7 +32,11 @@ fn test_zeros_and_ones() {
#[test]
fn test_bf16_tensor() {
let data: Vec<bf16> = vec![bf16::from_f32(1.0), bf16::from_f32(2.5), bf16::from_f32(-3.0)];
let data: Vec<bf16> = vec![
bf16::from_f32(1.0),
bf16::from_f32(2.5),
bf16::from_f32(-3.0),
];
let t = Tensor::from_slice(&data, &[3]);
assert_eq!(t.dtype(), DType::BF16);
let out = t.as_slice::<bf16>();

View File

@@ -95,11 +95,15 @@ impl Tokenizer {
let (a_str, b_str) = match entry {
MergeEntry::Str(s) => {
let parts: Vec<&str> = s.splitn(2, ' ').collect();
if parts.len() != 2 { continue; }
if parts.len() != 2 {
continue;
}
(parts[0].to_string(), parts[1].to_string())
}
MergeEntry::Pair(v) => {
if v.len() != 2 { continue; }
if v.len() != 2 {
continue;
}
(v[0].clone(), v[1].clone())
}
};
@@ -174,7 +178,10 @@ impl Tokenizer {
if byte_fallback {
Regex::new(r"[\p{L}\p{N}]+|[^\s\p{L}\p{N}]|\s+").unwrap()
} else {
Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+").unwrap()
Regex::new(
r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+",
)
.unwrap()
}
}
}
@@ -262,7 +269,9 @@ impl Tokenizer {
// BPE merges
loop {
if token_ids.len() < 2 { break; }
if token_ids.len() < 2 {
break;
}
let mut best_rank = usize::MAX;
let mut best_idx = 0;
for i in 0..token_ids.len() - 1 {
@@ -273,12 +282,15 @@ impl Tokenizer {
}
}
}
if best_rank == usize::MAX { break; }
if best_rank == usize::MAX {
break;
}
let merged_bytes = [
self.decoder[token_ids[best_idx] as usize].as_slice(),
self.decoder[token_ids[best_idx + 1] as usize].as_slice(),
].concat();
]
.concat();
let merged_id = *self.encoder.get(&merged_bytes).unwrap_or_else(|| {
panic!("merged token not in vocab");
});
@@ -389,14 +401,13 @@ fn unicode_to_byte(c: char) -> u8 {
m
});
*map.get(&(c as u32)).unwrap_or_else(|| {
panic!("unmapped unicode char U+{:04X} in tokenizer", c as u32)
})
*map.get(&(c as u32))
.unwrap_or_else(|| panic!("unmapped unicode char U+{:04X} in tokenizer", c as u32))
}
#[cfg(test)]
mod tests {
use super::{take_valid_utf8, Tokenizer};
use super::{Tokenizer, take_valid_utf8};
#[test]
fn qwen_added_tokens_are_indivisible_and_im_end_is_eos() {