style: format Rust workspace
This commit is contained in:
@@ -165,7 +165,10 @@ pub fn begin_retain() {
|
||||
/// Stop quarantining and hand the quarantined blocks to the caller.
|
||||
pub fn end_retain() -> RetainedBlocks {
|
||||
RETAINED.with(|cell| {
|
||||
let list = cell.borrow_mut().take().expect("end_retain without begin_retain");
|
||||
let list = cell
|
||||
.borrow_mut()
|
||||
.take()
|
||||
.expect("end_retain without begin_retain");
|
||||
RetainedBlocks(list)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -48,9 +48,7 @@ pub fn device_info(device: u32) -> Result<DeviceInfo> {
|
||||
// Heap-allocate oversized buffer for cudaDeviceProp (layout varies by CUDA version).
|
||||
// CUDA 12.x struct is ~5-6 KB; use 32 KB to guard against future growth.
|
||||
let mut prop_buf = vec![0u8; 32768];
|
||||
error::check(unsafe {
|
||||
ffi::cudaGetDeviceProperties(prop_buf.as_mut_ptr(), device as i32)
|
||||
})?;
|
||||
error::check(unsafe { ffi::cudaGetDeviceProperties(prop_buf.as_mut_ptr(), device as i32) })?;
|
||||
// Name is always the first field: char[256].
|
||||
let name = unsafe { CStr::from_ptr(prop_buf.as_ptr() as *const c_char) }
|
||||
.to_string_lossy()
|
||||
|
||||
@@ -64,11 +64,5 @@ unsafe extern "C" {
|
||||
pub fn cudaGraphExecDestroy(graph_exec: CudaGraphExec) -> i32;
|
||||
|
||||
// --- Our test kernel ---
|
||||
pub fn launch_vecadd_f32(
|
||||
a: *const f32,
|
||||
b: *const f32,
|
||||
c: *mut f32,
|
||||
n: i32,
|
||||
stream: CudaStream,
|
||||
);
|
||||
pub fn launch_vecadd_f32(a: *const f32, b: *const f32, c: *mut f32, n: i32, stream: CudaStream);
|
||||
}
|
||||
|
||||
@@ -54,30 +54,21 @@ impl CudaGraph {
|
||||
// made by THIS thread invalidate the capture. With GLOBAL mode, TP rank
|
||||
// threads capturing concurrently would poison each other's captures.
|
||||
error::check(unsafe {
|
||||
ffi::cudaStreamBeginCapture(
|
||||
stream.as_raw(),
|
||||
ffi::CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL,
|
||||
)
|
||||
ffi::cudaStreamBeginCapture(stream.as_raw(), ffi::CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL)
|
||||
})
|
||||
}
|
||||
|
||||
/// End capture and instantiate the executable graph.
|
||||
pub fn end_capture(&mut self, stream: &CudaStream) -> Result<()> {
|
||||
error::check(unsafe {
|
||||
ffi::cudaStreamEndCapture(stream.as_raw(), &mut self.graph)
|
||||
})?;
|
||||
error::check(unsafe {
|
||||
ffi::cudaGraphInstantiate(&mut self.exec, self.graph, 0)
|
||||
})
|
||||
error::check(unsafe { ffi::cudaStreamEndCapture(stream.as_raw(), &mut self.graph) })?;
|
||||
error::check(unsafe { ffi::cudaGraphInstantiate(&mut self.exec, self.graph, 0) })
|
||||
}
|
||||
|
||||
/// Replay the captured graph on `stream`.
|
||||
/// Panics if no graph has been captured yet.
|
||||
pub fn launch(&self, stream: &CudaStream) -> Result<()> {
|
||||
assert!(self.is_ready(), "CudaGraph::launch called before capture");
|
||||
error::check(unsafe {
|
||||
ffi::cudaGraphLaunch(self.exec, stream.as_raw())
|
||||
})
|
||||
error::check(unsafe { ffi::cudaGraphLaunch(self.exec, stream.as_raw()) })
|
||||
}
|
||||
|
||||
fn destroy_inner(&mut self) {
|
||||
|
||||
@@ -11,4 +11,4 @@ pub use device::DeviceInfo;
|
||||
pub use error::{CudaError, Result};
|
||||
pub use graph::CudaGraph;
|
||||
pub use memory::{GpuBuffer, PinnedBuffer};
|
||||
pub use stream::{current_stream_raw, push_stream, CudaStream, StreamGuard};
|
||||
pub use stream::{CudaStream, StreamGuard, current_stream_raw, push_stream};
|
||||
|
||||
@@ -22,7 +22,12 @@ impl GpuBuffer {
|
||||
assert!(len > 0, "cannot allocate 0 bytes on GPU");
|
||||
let mut ptr = std::ptr::null_mut();
|
||||
error::check(unsafe { ffi::cudaMalloc(&mut ptr, len) })?;
|
||||
Ok(Self { ptr, len, owned: true, pooled: false })
|
||||
Ok(Self {
|
||||
ptr,
|
||||
len,
|
||||
owned: true,
|
||||
pooled: false,
|
||||
})
|
||||
}
|
||||
|
||||
/// Mark this buffer as pooled (returned to caching allocator on drop)
|
||||
@@ -92,9 +97,7 @@ impl GpuBuffer {
|
||||
/// Copy from another GPU buffer (D2D).
|
||||
pub fn copy_from_device(&mut self, src: &GpuBuffer) -> Result<()> {
|
||||
let n = src.len.min(self.len);
|
||||
error::check(unsafe {
|
||||
ffi::cudaMemcpy(self.ptr, src.ptr, n, ffi::CUDA_MEMCPY_D2D)
|
||||
})
|
||||
error::check(unsafe { ffi::cudaMemcpy(self.ptr, src.ptr, n, ffi::CUDA_MEMCPY_D2D) })
|
||||
}
|
||||
|
||||
/// Fill buffer with zeros.
|
||||
@@ -103,7 +106,13 @@ impl GpuBuffer {
|
||||
}
|
||||
|
||||
/// Copy `count` bytes from `src` buffer at `src_offset` to this buffer at `dst_offset`.
|
||||
pub fn copy_from_device_at(&mut self, src: &GpuBuffer, src_offset: usize, dst_offset: usize, count: usize) -> Result<()> {
|
||||
pub fn copy_from_device_at(
|
||||
&mut self,
|
||||
src: &GpuBuffer,
|
||||
src_offset: usize,
|
||||
dst_offset: usize,
|
||||
count: usize,
|
||||
) -> Result<()> {
|
||||
assert!(src_offset + count <= src.len);
|
||||
assert!(dst_offset + count <= self.len);
|
||||
error::check(unsafe {
|
||||
@@ -117,7 +126,14 @@ impl GpuBuffer {
|
||||
}
|
||||
|
||||
/// Async copy `count` bytes from `src` at `src_offset` to `self` at `dst_offset` on `stream`.
|
||||
pub fn copy_from_device_at_async(&mut self, src: &GpuBuffer, src_offset: usize, dst_offset: usize, count: usize, stream: &CudaStream) -> Result<()> {
|
||||
pub fn copy_from_device_at_async(
|
||||
&mut self,
|
||||
src: &GpuBuffer,
|
||||
src_offset: usize,
|
||||
dst_offset: usize,
|
||||
count: usize,
|
||||
stream: &CudaStream,
|
||||
) -> Result<()> {
|
||||
assert!(src_offset + count <= src.len);
|
||||
assert!(dst_offset + count <= self.len);
|
||||
error::check(unsafe {
|
||||
@@ -161,9 +177,7 @@ impl GpuBuffer {
|
||||
|
||||
/// Async zero fill on stream.
|
||||
pub fn zero_async(&mut self, stream: &CudaStream) -> Result<()> {
|
||||
error::check(unsafe {
|
||||
ffi::cudaMemsetAsync(self.ptr, 0, self.len, stream.as_raw())
|
||||
})
|
||||
error::check(unsafe { ffi::cudaMemsetAsync(self.ptr, 0, self.len, stream.as_raw()) })
|
||||
}
|
||||
|
||||
/// Consume the buffer without freeing GPU memory. Returns the raw pointer and length.
|
||||
@@ -178,7 +192,12 @@ impl GpuBuffer {
|
||||
/// Reconstruct a GpuBuffer from a raw pointer + length.
|
||||
/// Safety: ptr must have been allocated with cudaMalloc, len must be correct.
|
||||
pub unsafe fn from_raw(ptr: *mut u8, len: usize) -> Self {
|
||||
Self { ptr, len, owned: true, pooled: false }
|
||||
Self {
|
||||
ptr,
|
||||
len,
|
||||
owned: true,
|
||||
pooled: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a non-owning view of GPU memory. Dropping this buffer does NOT
|
||||
@@ -189,7 +208,12 @@ impl GpuBuffer {
|
||||
/// `ptr` must point to a valid GPU allocation of at least `len` bytes that
|
||||
/// will remain live for the lifetime of the returned `GpuBuffer`.
|
||||
pub unsafe fn borrow_raw(ptr: *mut u8, len: usize) -> Self {
|
||||
Self { ptr, len, owned: false, pooled: false }
|
||||
Self {
|
||||
ptr,
|
||||
len,
|
||||
owned: false,
|
||||
pooled: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,10 @@ fn test_device_info() {
|
||||
info.compute_major, info.compute_minor
|
||||
);
|
||||
println!(" SM Count: {}", info.sm_count);
|
||||
println!(" Shared Mem/Block: {} KB", info.shared_mem_per_block / 1024);
|
||||
println!(
|
||||
" Shared Mem/Block: {} KB",
|
||||
info.shared_mem_per_block / 1024
|
||||
);
|
||||
println!(" Warp Size: {}", info.warp_size);
|
||||
println!(" Max Threads/Block: {}", info.max_threads_per_block);
|
||||
|
||||
@@ -145,7 +148,11 @@ fn test_caching_allocator() {
|
||||
|
||||
// Second allocation of same size: should hit cache
|
||||
let _buf2 = alloc.alloc(1024).unwrap();
|
||||
assert_eq!(alloc.stats().cuda_malloc_count, 1, "should reuse cached buffer");
|
||||
assert_eq!(
|
||||
alloc.stats().cuda_malloc_count,
|
||||
1,
|
||||
"should reuse cached buffer"
|
||||
);
|
||||
assert_eq!(alloc.stats().cache_hit_count, 1);
|
||||
}
|
||||
|
||||
@@ -198,11 +205,17 @@ fn test_async_copy() {
|
||||
}
|
||||
|
||||
let mut gpu = GpuBuffer::alloc(4096).unwrap();
|
||||
unsafe { gpu.copy_from_host_async(pinned.as_slice(), &stream).unwrap() };
|
||||
unsafe {
|
||||
gpu.copy_from_host_async(pinned.as_slice(), &stream)
|
||||
.unwrap()
|
||||
};
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
let mut out_pinned = PinnedBuffer::alloc(4096).unwrap();
|
||||
unsafe { gpu.copy_to_host_async(out_pinned.as_mut_slice(), &stream).unwrap() };
|
||||
unsafe {
|
||||
gpu.copy_to_host_async(out_pinned.as_mut_slice(), &stream)
|
||||
.unwrap()
|
||||
};
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
assert_eq!(pinned.as_slice(), out_pinned.as_slice());
|
||||
|
||||
@@ -34,7 +34,12 @@ pub const NCCL_SUCCESS: i32 = 0;
|
||||
unsafe extern "C" {
|
||||
pub fn ncclGetUniqueId(uid: *mut NcclUniqueId) -> i32;
|
||||
// ncclUniqueId is passed BY VALUE (a 128-byte struct) per the NCCL ABI.
|
||||
pub fn ncclCommInitRank(comm: *mut NcclComm, nranks: i32, commid: NcclUniqueId, rank: i32) -> i32;
|
||||
pub fn ncclCommInitRank(
|
||||
comm: *mut NcclComm,
|
||||
nranks: i32,
|
||||
commid: NcclUniqueId,
|
||||
rank: i32,
|
||||
) -> i32;
|
||||
pub fn ncclCommDestroy(comm: NcclComm) -> i32;
|
||||
pub fn ncclAllReduce(
|
||||
sendbuff: *const c_void,
|
||||
@@ -78,5 +83,10 @@ pub fn err_string(result: i32) -> String {
|
||||
}
|
||||
|
||||
pub fn check(result: i32, what: &str) {
|
||||
assert_eq!(result, NCCL_SUCCESS, "{what} failed: {}", err_string(result));
|
||||
assert_eq!(
|
||||
result,
|
||||
NCCL_SUCCESS,
|
||||
"{what} failed: {}",
|
||||
err_string(result)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ pub mod ffi;
|
||||
use std::ffi::c_void;
|
||||
|
||||
use ffi::{NcclComm, NcclUniqueId};
|
||||
use xserv_cuda::device;
|
||||
use xserv_cuda::GpuBuffer;
|
||||
use xserv_cuda::device;
|
||||
|
||||
pub use ffi::NcclUniqueId as UniqueId;
|
||||
|
||||
@@ -55,7 +55,12 @@ impl TpContext {
|
||||
"ncclCommInitRank",
|
||||
);
|
||||
ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(init)");
|
||||
Self { rank, world, device, comm }
|
||||
Self {
|
||||
rank,
|
||||
world,
|
||||
device,
|
||||
comm,
|
||||
}
|
||||
}
|
||||
|
||||
/// In-place AllReduce(sum) over `count` BF16 elements in `buf`.
|
||||
@@ -127,7 +132,12 @@ impl PpContext {
|
||||
"ncclCommInitRank",
|
||||
);
|
||||
ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(init)");
|
||||
Self { stage, world, device, comm }
|
||||
Self {
|
||||
stage,
|
||||
world,
|
||||
device,
|
||||
comm,
|
||||
}
|
||||
}
|
||||
|
||||
/// Send `count` BF16 elements at `ptr` to `peer`, on the null stream so it is
|
||||
@@ -138,7 +148,16 @@ impl PpContext {
|
||||
/// `ptr` must point to at least `count` BF16 elements of valid device memory.
|
||||
pub fn send_bf16_ptr(&self, ptr: *const c_void, count: usize, peer: usize) {
|
||||
ffi::check(
|
||||
unsafe { ffi::ncclSend(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, launch_stream()) },
|
||||
unsafe {
|
||||
ffi::ncclSend(
|
||||
ptr,
|
||||
count,
|
||||
ffi::NCCL_BF16,
|
||||
peer as i32,
|
||||
self.comm,
|
||||
launch_stream(),
|
||||
)
|
||||
},
|
||||
"ncclSend",
|
||||
);
|
||||
}
|
||||
@@ -149,7 +168,16 @@ impl PpContext {
|
||||
/// `ptr` must point to at least `count` BF16 elements of valid device memory.
|
||||
pub fn recv_bf16_ptr(&self, ptr: *mut c_void, count: usize, peer: usize) {
|
||||
ffi::check(
|
||||
unsafe { ffi::ncclRecv(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, launch_stream()) },
|
||||
unsafe {
|
||||
ffi::ncclRecv(
|
||||
ptr,
|
||||
count,
|
||||
ffi::NCCL_BF16,
|
||||
peer as i32,
|
||||
self.comm,
|
||||
launch_stream(),
|
||||
)
|
||||
},
|
||||
"ncclRecv",
|
||||
);
|
||||
}
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
use half::bf16;
|
||||
use std::thread;
|
||||
use xserv_cuda::{device, GpuBuffer};
|
||||
use xserv_distributed::{get_unique_id, TpContext};
|
||||
use xserv_cuda::{GpuBuffer, device};
|
||||
use xserv_distributed::{TpContext, get_unique_id};
|
||||
|
||||
#[test]
|
||||
fn allreduce_two_gpu_sum() {
|
||||
@@ -25,9 +25,7 @@ fn allreduce_two_gpu_sum() {
|
||||
// Rank r fills its buffer with (r + 1).
|
||||
let val = bf16::from_f32((rank + 1) as f32);
|
||||
let host = vec![val; n];
|
||||
let src = unsafe {
|
||||
std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2)
|
||||
};
|
||||
let src = unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) };
|
||||
let mut buf = GpuBuffer::alloc(n * 2).unwrap();
|
||||
buf.copy_from_host(src).unwrap();
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
use half::bf16;
|
||||
use std::ffi::c_void;
|
||||
use std::thread;
|
||||
use xserv_cuda::{device, GpuBuffer};
|
||||
use xserv_distributed::{get_unique_id, PpContext};
|
||||
use xserv_cuda::{GpuBuffer, device};
|
||||
use xserv_distributed::{PpContext, get_unique_id};
|
||||
|
||||
#[test]
|
||||
fn pp_send_recv_two_stages() {
|
||||
@@ -30,7 +30,8 @@ fn pp_send_recv_two_stages() {
|
||||
if stage == 0 {
|
||||
// Fill with a known pattern and send to stage 1.
|
||||
let host: Vec<bf16> = (0..n).map(|i| bf16::from_f32((i % 97) as f32)).collect();
|
||||
let src = unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) };
|
||||
let src =
|
||||
unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) };
|
||||
buf.copy_from_host(src).unwrap();
|
||||
pp.send_bf16_ptr(buf.as_mut_ptr() as *const c_void, n, 1);
|
||||
device::synchronize().unwrap();
|
||||
|
||||
@@ -6,78 +6,189 @@ unsafe extern "C" {
|
||||
fn launch_gelu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_silu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_silu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_scale_f32(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void);
|
||||
fn launch_scale_bf16(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void);
|
||||
fn launch_add_f32(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_mul_f32(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_mul_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_gpt_oss_glu_bf16(gate_up: *const c_void, out: *mut c_void, n_elements: i32,
|
||||
alpha: f32, limit: f32, stream: *mut c_void);
|
||||
fn launch_bias_add_2d_bf16(x: *const c_void, bias: *const c_void, out: *mut c_void,
|
||||
rows: i32, cols: i32, stream: *mut c_void);
|
||||
fn launch_scale_f32(
|
||||
x: *const c_void,
|
||||
out: *mut c_void,
|
||||
scale: f32,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_scale_bf16(
|
||||
x: *const c_void,
|
||||
out: *mut c_void,
|
||||
scale: f32,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_add_f32(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_add_bf16(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_mul_f32(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_mul_bf16(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_silu_mul_bf16(
|
||||
gate: *const c_void,
|
||||
up: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_gpt_oss_glu_bf16(
|
||||
gate_up: *const c_void,
|
||||
out: *mut c_void,
|
||||
n_elements: i32,
|
||||
alpha: f32,
|
||||
limit: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_bias_add_2d_bf16(
|
||||
x: *const c_void,
|
||||
bias: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
|
||||
bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
|
||||
fn dispatch_unary(
|
||||
x: &Tensor,
|
||||
f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
|
||||
bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
|
||||
) -> Tensor {
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
let n = x.numel();
|
||||
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||||
assert!(
|
||||
n <= i32::MAX as usize,
|
||||
"tensor too large for i32 kernel param ({n} elements)"
|
||||
);
|
||||
let n = n as i32;
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
|
||||
DType::BF16 => bf16_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
|
||||
DType::F32 => f32_fn(
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => bf16_fn(
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype"),
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn dispatch_binary(a: &Tensor, b: &Tensor,
|
||||
fn dispatch_binary(
|
||||
a: &Tensor,
|
||||
b: &Tensor,
|
||||
f32_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void),
|
||||
bf16_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
|
||||
bf16_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void),
|
||||
) -> Tensor {
|
||||
assert_eq!(a.shape(), b.shape());
|
||||
assert!(a.is_contiguous() && b.is_contiguous());
|
||||
assert!(matches!(a.device(), Device::Cuda(_)));
|
||||
assert_eq!(a.dtype(), b.dtype());
|
||||
let out = Tensor::empty(a.shape(), a.dtype(), a.device());
|
||||
let n = a.numel();
|
||||
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||||
assert!(
|
||||
n <= i32::MAX as usize,
|
||||
"tensor too large for i32 kernel param ({n} elements)"
|
||||
);
|
||||
let n = n as i32;
|
||||
unsafe {
|
||||
match a.dtype() {
|
||||
DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
|
||||
DType::BF16 => bf16_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
|
||||
DType::F32 => f32_fn(
|
||||
a.data_ptr() as _,
|
||||
b.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => bf16_fn(
|
||||
a.data_ptr() as _,
|
||||
b.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype"),
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub fn gelu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_gelu_f32, launch_gelu_bf16) }
|
||||
pub fn silu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_silu_f32, launch_silu_bf16) }
|
||||
pub fn gelu(x: &Tensor) -> Tensor {
|
||||
dispatch_unary(x, launch_gelu_f32, launch_gelu_bf16)
|
||||
}
|
||||
pub fn silu(x: &Tensor) -> Tensor {
|
||||
dispatch_unary(x, launch_silu_f32, launch_silu_bf16)
|
||||
}
|
||||
|
||||
pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
let n = x.numel();
|
||||
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||||
assert!(
|
||||
n <= i32::MAX as usize,
|
||||
"tensor too large for i32 kernel param ({n} elements)"
|
||||
);
|
||||
let n = n as i32;
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, xserv_cuda::current_stream_raw()),
|
||||
DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, xserv_cuda::current_stream_raw()),
|
||||
DType::F32 => launch_scale_f32(
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
scale_val,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_scale_bf16(
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
scale_val,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for scale"),
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub fn add(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_add_f32, launch_add_bf16) }
|
||||
pub fn mul(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16) }
|
||||
pub fn add(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
dispatch_binary(a, b, launch_add_f32, launch_add_bf16)
|
||||
}
|
||||
pub fn mul(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16)
|
||||
}
|
||||
|
||||
/// Row-broadcast bias add: out[r, c] = x[r, c] + bias[c] (BF16 only).
|
||||
pub fn bias_add_2d(x: &Tensor, bias: &Tensor) -> Tensor {
|
||||
@@ -89,13 +200,22 @@ pub fn bias_add_2d(x: &Tensor, bias: &Tensor) -> Tensor {
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
let rows = x.shape()[0];
|
||||
let cols = x.shape()[1];
|
||||
assert_eq!(bias.shape()[0], cols, "bias size {} != cols {cols}", bias.shape()[0]);
|
||||
assert_eq!(
|
||||
bias.shape()[0],
|
||||
cols,
|
||||
"bias size {} != cols {cols}",
|
||||
bias.shape()[0]
|
||||
);
|
||||
assert!(rows * cols <= i32::MAX as usize);
|
||||
let out = Tensor::empty(&[rows, cols], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_bias_add_2d_bf16(
|
||||
x.data_ptr() as _, bias.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, cols as i32, xserv_cuda::current_stream_raw(),
|
||||
x.data_ptr() as _,
|
||||
bias.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
@@ -110,7 +230,10 @@ pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {
|
||||
assert_eq!(gate.dtype(), DType::BF16, "silu_mul requires BF16");
|
||||
let out = Tensor::empty(gate.shape(), gate.dtype(), gate.device());
|
||||
let n = gate.numel();
|
||||
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||||
assert!(
|
||||
n <= i32::MAX as usize,
|
||||
"tensor too large for i32 kernel param ({n} elements)"
|
||||
);
|
||||
let n = n as i32;
|
||||
unsafe {
|
||||
launch_silu_mul_bf16(
|
||||
|
||||
@@ -2,8 +2,13 @@ use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_argmax_bf16(logits: *const c_void, out_idx: *mut c_void,
|
||||
rows: i32, cols: i32, stream: *mut c_void);
|
||||
fn launch_argmax_bf16(
|
||||
logits: *const c_void,
|
||||
out_idx: *mut c_void,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
/// GPU argmax over the last dim of a [rows, cols] BF16 tensor.
|
||||
@@ -19,7 +24,10 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec<u32> {
|
||||
assert_eq!(logits.ndim(), 2, "argmax expects a 2D [rows, cols] tensor");
|
||||
assert_eq!(logits.dtype(), DType::BF16, "argmax kernel is BF16-only");
|
||||
assert!(logits.is_contiguous(), "argmax requires contiguous input");
|
||||
assert!(matches!(logits.device(), Device::Cuda(_)), "argmax requires GPU input");
|
||||
assert!(
|
||||
matches!(logits.device(), Device::Cuda(_)),
|
||||
"argmax requires GPU input"
|
||||
);
|
||||
|
||||
let rows = logits.shape()[0];
|
||||
let cols = logits.shape()[1];
|
||||
@@ -35,7 +43,8 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec<u32> {
|
||||
launch_argmax_bf16(
|
||||
logits.data_ptr() as *const c_void,
|
||||
out.as_mut_ptr() as *mut c_void,
|
||||
rows as i32, cols as i32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
@@ -44,9 +53,8 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec<u32> {
|
||||
out.copy_to_host(&mut host_bytes).expect("argmax D2H");
|
||||
drop(out); // returned to pool
|
||||
|
||||
let host_i32: &[i32] = unsafe {
|
||||
std::slice::from_raw_parts(host_bytes.as_ptr() as *const i32, rows)
|
||||
};
|
||||
let host_i32: &[i32] =
|
||||
unsafe { std::slice::from_raw_parts(host_bytes.as_ptr() as *const i32, rows) };
|
||||
host_i32.iter().map(|&v| v as u32).collect()
|
||||
}
|
||||
|
||||
@@ -62,4 +70,3 @@ pub fn argmax_bf16_single(logits: &Tensor) -> u32 {
|
||||
};
|
||||
argmax_bf16_to_host(&view)[0]
|
||||
}
|
||||
|
||||
|
||||
@@ -6,28 +6,67 @@ use crate::gemm::batched_matmul;
|
||||
use crate::softmax::softmax;
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_causal_mask_f32(scores: *mut c_void, batch: i32, rows: i32, cols: i32,
|
||||
offset: i32, stream: *mut c_void);
|
||||
fn launch_causal_mask_bf16(scores: *mut c_void, batch: i32, rows: i32, cols: i32,
|
||||
offset: i32, stream: *mut c_void);
|
||||
fn launch_causal_mask_f32(
|
||||
scores: *mut c_void,
|
||||
batch: i32,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
offset: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_causal_mask_bf16(
|
||||
scores: *mut c_void,
|
||||
batch: i32,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
offset: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_flash_attention_bf16(
|
||||
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
|
||||
batch: i32, num_q_heads: i32, num_kv_heads: i32,
|
||||
q_len: i32, kv_len: i32, head_dim: i32,
|
||||
scale: f32, causal: i32, stream: *mut c_void,
|
||||
q: *const c_void,
|
||||
k: *const c_void,
|
||||
v: *const c_void,
|
||||
o: *mut c_void,
|
||||
batch: i32,
|
||||
num_q_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
q_len: i32,
|
||||
kv_len: i32,
|
||||
head_dim: i32,
|
||||
scale: f32,
|
||||
causal: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_flash_attention_sinks_bf16(
|
||||
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
|
||||
q: *const c_void,
|
||||
k: *const c_void,
|
||||
v: *const c_void,
|
||||
o: *mut c_void,
|
||||
sinks: *const c_void,
|
||||
batch: i32, num_q_heads: i32, num_kv_heads: i32,
|
||||
q_len: i32, kv_len: i32, head_dim: i32,
|
||||
scale: f32, causal: i32, window_size: i32, stream: *mut c_void,
|
||||
batch: i32,
|
||||
num_q_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
q_len: i32,
|
||||
kv_len: i32,
|
||||
head_dim: i32,
|
||||
scale: f32,
|
||||
causal: i32,
|
||||
window_size: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_decode_attention_bf16(
|
||||
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
|
||||
batch: i32, num_q_heads: i32, num_kv_heads: i32,
|
||||
kv_len: i32, head_dim: i32,
|
||||
scale: f32, causal: i32, stream: *mut c_void,
|
||||
q: *const c_void,
|
||||
k: *const c_void,
|
||||
v: *const c_void,
|
||||
o: *mut c_void,
|
||||
batch: i32,
|
||||
num_q_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
kv_len: i32,
|
||||
head_dim: i32,
|
||||
scale: f32,
|
||||
causal: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_paged_decode_attention_bf16(
|
||||
q: *const c_void,
|
||||
@@ -36,9 +75,13 @@ unsafe extern "C" {
|
||||
o: *mut c_void,
|
||||
block_tables: *const i32,
|
||||
context_lens: *const i32,
|
||||
batch: i32, num_q_heads: i32, num_kv_heads: i32,
|
||||
head_dim: i32, max_blocks_per_seq: i32,
|
||||
scale: f32, stream: *mut c_void,
|
||||
batch: i32,
|
||||
num_q_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
head_dim: i32,
|
||||
max_blocks_per_seq: i32,
|
||||
scale: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_paged_decode_attention_sinks_bf16(
|
||||
q: *const c_void,
|
||||
@@ -48,24 +91,40 @@ unsafe extern "C" {
|
||||
block_tables: *const i32,
|
||||
context_lens: *const i32,
|
||||
sinks: *const c_void,
|
||||
batch: i32, num_q_heads: i32, num_kv_heads: i32,
|
||||
head_dim: i32, max_blocks_per_seq: i32,
|
||||
scale: f32, window_size: i32, stream: *mut c_void,
|
||||
batch: i32,
|
||||
num_q_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
head_dim: i32,
|
||||
max_blocks_per_seq: i32,
|
||||
scale: f32,
|
||||
window_size: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_reshape_and_cache_bf16(
|
||||
k_src: *const c_void, v_src: *const c_void,
|
||||
k_pool: *mut c_void, v_pool: *mut c_void,
|
||||
k_src: *const c_void,
|
||||
v_src: *const c_void,
|
||||
k_pool: *mut c_void,
|
||||
v_pool: *mut c_void,
|
||||
block_ids: *const c_void,
|
||||
num_tokens: i32, num_heads: i32,
|
||||
head_dim: i32, start_pos: i32, block_size: i32,
|
||||
num_tokens: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
start_pos: i32,
|
||||
block_size: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_reshape_and_cache_batched_bf16(
|
||||
k_src: *const c_void, v_src: *const c_void,
|
||||
k_pool: *mut c_void, v_pool: *mut c_void,
|
||||
block_tables: *const c_void, kv_lens: *const c_void,
|
||||
batch: i32, num_heads: i32,
|
||||
head_dim: i32, block_size: i32, max_blocks_per_seq: i32,
|
||||
k_src: *const c_void,
|
||||
v_src: *const c_void,
|
||||
k_pool: *mut c_void,
|
||||
v_pool: *mut c_void,
|
||||
block_tables: *const c_void,
|
||||
kv_lens: *const c_void,
|
||||
batch: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
block_size: i32,
|
||||
max_blocks_per_seq: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
@@ -84,20 +143,30 @@ unsafe extern "C" {
|
||||
/// `block_ids_gpu` must contain at least `(start_pos + num_tokens + block_size - 1) / block_size`
|
||||
/// valid physical block ids.
|
||||
pub unsafe fn reshape_and_cache_bf16(
|
||||
k_src: *const c_void, v_src: *const c_void,
|
||||
k_pool_ptr: *mut c_void, v_pool_ptr: *mut c_void,
|
||||
k_src: *const c_void,
|
||||
v_src: *const c_void,
|
||||
k_pool_ptr: *mut c_void,
|
||||
v_pool_ptr: *mut c_void,
|
||||
block_ids_gpu: *const i32,
|
||||
num_tokens: usize, num_heads: usize,
|
||||
head_dim: usize, start_pos: usize, block_size: usize,
|
||||
num_tokens: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
start_pos: usize,
|
||||
block_size: usize,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
unsafe {
|
||||
launch_reshape_and_cache_bf16(
|
||||
k_src, v_src,
|
||||
k_pool_ptr, v_pool_ptr,
|
||||
k_src,
|
||||
v_src,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
block_ids_gpu as *const c_void,
|
||||
num_tokens as i32, num_heads as i32,
|
||||
head_dim as i32, start_pos as i32, block_size as i32,
|
||||
num_tokens as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
start_pos as i32,
|
||||
block_size as i32,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
@@ -113,21 +182,32 @@ pub unsafe fn reshape_and_cache_bf16(
|
||||
/// All pointers must be on the same GPU. `block_tables` and `kv_lens` must
|
||||
/// already be synced to the device for the active batch.
|
||||
pub unsafe fn reshape_and_cache_batched_bf16(
|
||||
k_src: *const c_void, v_src: *const c_void,
|
||||
k_pool_ptr: *mut c_void, v_pool_ptr: *mut c_void,
|
||||
block_tables_gpu: *const i32, kv_lens_gpu: *const i32,
|
||||
batch: usize, num_heads: usize,
|
||||
head_dim: usize, block_size: usize, max_blocks_per_seq: usize,
|
||||
k_src: *const c_void,
|
||||
v_src: *const c_void,
|
||||
k_pool_ptr: *mut c_void,
|
||||
v_pool_ptr: *mut c_void,
|
||||
block_tables_gpu: *const i32,
|
||||
kv_lens_gpu: *const i32,
|
||||
batch: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
block_size: usize,
|
||||
max_blocks_per_seq: usize,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
unsafe {
|
||||
launch_reshape_and_cache_batched_bf16(
|
||||
k_src, v_src,
|
||||
k_pool_ptr, v_pool_ptr,
|
||||
k_src,
|
||||
v_src,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
block_tables_gpu as *const c_void,
|
||||
kv_lens_gpu as *const c_void,
|
||||
batch as i32, num_heads as i32,
|
||||
head_dim as i32, block_size as i32, max_blocks_per_seq as i32,
|
||||
batch as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
block_size as i32,
|
||||
max_blocks_per_seq as i32,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
@@ -143,12 +223,18 @@ fn apply_causal_mask(scores: &Tensor, offset: usize) {
|
||||
match scores.dtype() {
|
||||
DType::F32 => launch_causal_mask_f32(
|
||||
scores.data_ptr() as *mut c_void,
|
||||
batch as i32, rows as i32, cols as i32, offset as i32,
|
||||
batch as i32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
offset as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_causal_mask_bf16(
|
||||
scores.data_ptr() as *mut c_void,
|
||||
batch as i32, rows as i32, cols as i32, offset as i32,
|
||||
batch as i32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
offset as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for causal mask"),
|
||||
@@ -214,11 +300,7 @@ pub fn decode_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Tensor {
|
||||
let kv_len = k.shape()[2];
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::empty(
|
||||
&[batch, num_q_heads, 1, head_dim],
|
||||
DType::BF16,
|
||||
q.device(),
|
||||
);
|
||||
let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
|
||||
|
||||
unsafe {
|
||||
launch_decode_attention_bf16(
|
||||
@@ -266,8 +348,14 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens
|
||||
|
||||
assert_eq!(k.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
|
||||
assert_eq!(v.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
|
||||
assert!(num_q_heads % num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads");
|
||||
assert!(head_dim <= 128, "flash_attention supports head_dim up to 128");
|
||||
assert!(
|
||||
num_q_heads % num_kv_heads == 0,
|
||||
"num_q_heads must be divisible by num_kv_heads"
|
||||
);
|
||||
assert!(
|
||||
head_dim <= 128,
|
||||
"flash_attention supports head_dim up to 128"
|
||||
);
|
||||
|
||||
// Dispatch to specialized decode kernel for single-token generation
|
||||
if q_len == 1 {
|
||||
@@ -333,10 +421,18 @@ pub fn flash_attention_sinks(
|
||||
assert_eq!(v.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
|
||||
assert!(num_q_heads % num_kv_heads == 0);
|
||||
assert!(head_dim <= 128);
|
||||
assert_eq!(sinks.shape()[0], num_q_heads, "sinks must have num_q_heads entries");
|
||||
assert_eq!(
|
||||
sinks.shape()[0],
|
||||
num_q_heads,
|
||||
"sinks must have num_q_heads entries"
|
||||
);
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::empty(&[batch, num_q_heads, q_len, head_dim], DType::BF16, q.device());
|
||||
let output = Tensor::empty(
|
||||
&[batch, num_q_heads, q_len, head_dim],
|
||||
DType::BF16,
|
||||
q.device(),
|
||||
);
|
||||
|
||||
unsafe {
|
||||
launch_flash_attention_sinks_bf16(
|
||||
@@ -383,17 +479,20 @@ pub fn paged_decode_attention(
|
||||
max_blocks_per_seq: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(q.ndim(), 4);
|
||||
assert_eq!(q.shape()[2], 1, "paged_decode_attention requires q_len == 1");
|
||||
assert_eq!(
|
||||
q.shape()[2],
|
||||
1,
|
||||
"paged_decode_attention requires q_len == 1"
|
||||
);
|
||||
assert_eq!(q.dtype(), DType::BF16);
|
||||
assert!(num_q_heads % num_kv_heads == 0, "GQA: num_q_heads must be divisible by num_kv_heads");
|
||||
assert!(
|
||||
num_q_heads % num_kv_heads == 0,
|
||||
"GQA: num_q_heads must be divisible by num_kv_heads"
|
||||
);
|
||||
assert!(head_dim <= 128);
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::empty(
|
||||
&[batch, num_q_heads, 1, head_dim],
|
||||
DType::BF16,
|
||||
q.device(),
|
||||
);
|
||||
let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
|
||||
|
||||
unsafe {
|
||||
launch_paged_decode_attention_bf16(
|
||||
@@ -442,11 +541,7 @@ pub fn paged_decode_attention_sinks(
|
||||
assert!(head_dim <= 128);
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::empty(
|
||||
&[batch, num_q_heads, 1, head_dim],
|
||||
DType::BF16,
|
||||
q.device(),
|
||||
);
|
||||
let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
|
||||
|
||||
unsafe {
|
||||
launch_paged_decode_attention_sinks_bf16(
|
||||
|
||||
@@ -5,104 +5,302 @@ use std::ffi::c_void;
|
||||
|
||||
// Re-declare the extern functions we need (same as in the individual modules)
|
||||
unsafe extern "C" {
|
||||
fn launch_rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void,
|
||||
normed_out: *mut c_void, sum_out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
||||
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void);
|
||||
fn launch_reshape_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_merge_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
|
||||
positions: *const c_void, num_tokens: i32, num_heads: i32,
|
||||
head_dim: i32, stream: *mut c_void);
|
||||
fn launch_gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void, y_fp32_buf: *mut c_void,
|
||||
k: i32, n: i32, stream: *mut c_void);
|
||||
fn launch_rmsnorm_bf16(
|
||||
x: *const c_void,
|
||||
gamma: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
hidden_size: i32,
|
||||
eps: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_add_rmsnorm_bf16(
|
||||
x: *const c_void,
|
||||
residual: *const c_void,
|
||||
gamma: *const c_void,
|
||||
normed_out: *mut c_void,
|
||||
sum_out: *mut c_void,
|
||||
rows: i32,
|
||||
hidden_size: i32,
|
||||
eps: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_silu_mul_bf16(
|
||||
gate: *const c_void,
|
||||
up: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_add_bf16(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_embedding_bf16(
|
||||
table: *const c_void,
|
||||
token_ids: *const c_void,
|
||||
out: *mut c_void,
|
||||
num_tokens: i32,
|
||||
hidden_size: i32,
|
||||
vocab_size: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_reshape_heads_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_merge_heads_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_transpose_hsd_to_shd_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_transpose_shd_to_hsd_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_rope_bf16(
|
||||
x: *mut c_void,
|
||||
cos_cache: *const c_void,
|
||||
sin_cache: *const c_void,
|
||||
positions: *const c_void,
|
||||
num_tokens: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_gemv_bf16(
|
||||
x: *const c_void,
|
||||
w: *const c_void,
|
||||
y_bf16: *mut c_void,
|
||||
y_fp32_buf: *mut c_void,
|
||||
k: i32,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_decode_attention_bf16(
|
||||
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
|
||||
batch: i32, num_q_heads: i32, num_kv_heads: i32,
|
||||
kv_len: i32, head_dim: i32,
|
||||
scale: f32, causal: i32, stream: *mut c_void,
|
||||
q: *const c_void,
|
||||
k: *const c_void,
|
||||
v: *const c_void,
|
||||
o: *mut c_void,
|
||||
batch: i32,
|
||||
num_q_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
kv_len: i32,
|
||||
head_dim: i32,
|
||||
scale: f32,
|
||||
causal: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
/// Raw rmsnorm dispatch: writes to pre-allocated `out`.
|
||||
pub unsafe fn rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void) {
|
||||
pub unsafe fn rmsnorm_bf16(
|
||||
x: *const c_void,
|
||||
gamma: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
hidden_size: i32,
|
||||
eps: f32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_rmsnorm_bf16(x, gamma, out, rows, hidden_size, eps, stream);
|
||||
}
|
||||
|
||||
/// Raw add_rmsnorm dispatch.
|
||||
pub unsafe fn add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void,
|
||||
normed_out: *mut c_void, sum_out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void) {
|
||||
launch_add_rmsnorm_bf16(x, residual, gamma, normed_out, sum_out, rows, hidden_size, eps, stream);
|
||||
pub unsafe fn add_rmsnorm_bf16(
|
||||
x: *const c_void,
|
||||
residual: *const c_void,
|
||||
gamma: *const c_void,
|
||||
normed_out: *mut c_void,
|
||||
sum_out: *mut c_void,
|
||||
rows: i32,
|
||||
hidden_size: i32,
|
||||
eps: f32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_add_rmsnorm_bf16(
|
||||
x,
|
||||
residual,
|
||||
gamma,
|
||||
normed_out,
|
||||
sum_out,
|
||||
rows,
|
||||
hidden_size,
|
||||
eps,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
|
||||
/// Raw silu_mul dispatch.
|
||||
pub unsafe fn silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void) {
|
||||
pub unsafe fn silu_mul_bf16(
|
||||
gate: *const c_void,
|
||||
up: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_silu_mul_bf16(gate, up, out, n, stream);
|
||||
}
|
||||
|
||||
/// Raw add dispatch.
|
||||
pub unsafe fn add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void) {
|
||||
pub unsafe fn add_bf16(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_add_bf16(a, b, out, n, stream);
|
||||
}
|
||||
|
||||
/// Raw embedding dispatch.
|
||||
pub unsafe fn embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
||||
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void) {
|
||||
launch_embedding_bf16(table, token_ids, out, num_tokens, hidden_size, vocab_size, stream);
|
||||
pub unsafe fn embedding_bf16(
|
||||
table: *const c_void,
|
||||
token_ids: *const c_void,
|
||||
out: *mut c_void,
|
||||
num_tokens: i32,
|
||||
hidden_size: i32,
|
||||
vocab_size: i32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_embedding_bf16(
|
||||
table,
|
||||
token_ids,
|
||||
out,
|
||||
num_tokens,
|
||||
hidden_size,
|
||||
vocab_size,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
|
||||
/// Raw reshape_heads dispatch.
|
||||
pub unsafe fn reshape_heads_bf16(inp: *const c_void, out: *mut c_void,
|
||||
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
|
||||
pub unsafe fn reshape_heads_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_reshape_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
||||
}
|
||||
|
||||
/// Raw merge_heads dispatch.
|
||||
pub unsafe fn merge_heads_bf16(inp: *const c_void, out: *mut c_void,
|
||||
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
|
||||
pub unsafe fn merge_heads_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_merge_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
||||
}
|
||||
|
||||
/// Raw transpose HSD->SHD dispatch.
|
||||
pub unsafe fn transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void,
|
||||
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
|
||||
pub unsafe fn transpose_hsd_to_shd_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_transpose_hsd_to_shd_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
||||
}
|
||||
|
||||
/// Raw transpose SHD->HSD dispatch.
|
||||
pub unsafe fn transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void,
|
||||
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
|
||||
pub unsafe fn transpose_shd_to_hsd_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_transpose_shd_to_hsd_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
||||
}
|
||||
|
||||
/// Raw RoPE dispatch (in-place).
|
||||
pub unsafe fn rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
|
||||
positions: *const c_void, num_tokens: i32, num_heads: i32,
|
||||
head_dim: i32, stream: *mut c_void) {
|
||||
launch_rope_bf16(x, cos_cache, sin_cache, positions, num_tokens, num_heads, head_dim, stream);
|
||||
pub unsafe fn rope_bf16(
|
||||
x: *mut c_void,
|
||||
cos_cache: *const c_void,
|
||||
sin_cache: *const c_void,
|
||||
positions: *const c_void,
|
||||
num_tokens: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_rope_bf16(
|
||||
x, cos_cache, sin_cache, positions, num_tokens, num_heads, head_dim, stream,
|
||||
);
|
||||
}
|
||||
|
||||
/// Raw GEMV dispatch (BF16, M=1). Caller must provide fp32 accumulator buffer.
|
||||
pub unsafe fn gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void,
|
||||
y_fp32_buf: *mut c_void, k: i32, n: i32, stream: *mut c_void) {
|
||||
pub unsafe fn gemv_bf16(
|
||||
x: *const c_void,
|
||||
w: *const c_void,
|
||||
y_bf16: *mut c_void,
|
||||
y_fp32_buf: *mut c_void,
|
||||
k: i32,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_gemv_bf16(x, w, y_bf16, y_fp32_buf, k, n, stream);
|
||||
}
|
||||
|
||||
/// Raw decode attention dispatch.
|
||||
pub unsafe fn decode_attention_bf16(q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
|
||||
batch: i32, num_q_heads: i32, num_kv_heads: i32,
|
||||
kv_len: i32, head_dim: i32,
|
||||
scale: f32, stream: *mut c_void) {
|
||||
launch_decode_attention_bf16(q, k, v, o, batch, num_q_heads, num_kv_heads, kv_len, head_dim, scale, 1, stream);
|
||||
pub unsafe fn decode_attention_bf16(
|
||||
q: *const c_void,
|
||||
k: *const c_void,
|
||||
v: *const c_void,
|
||||
o: *mut c_void,
|
||||
batch: i32,
|
||||
num_q_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
kv_len: i32,
|
||||
head_dim: i32,
|
||||
scale: f32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_decode_attention_bf16(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
batch,
|
||||
num_q_heads,
|
||||
num_kv_heads,
|
||||
kv_len,
|
||||
head_dim,
|
||||
scale,
|
||||
1,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
|
||||
// cuBLAS FFI
|
||||
|
||||
@@ -2,10 +2,24 @@ use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_embedding_f32(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
||||
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void);
|
||||
fn launch_embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
||||
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void);
|
||||
fn launch_embedding_f32(
|
||||
table: *const c_void,
|
||||
token_ids: *const c_void,
|
||||
out: *mut c_void,
|
||||
num_tokens: i32,
|
||||
hidden_size: i32,
|
||||
vocab_size: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_embedding_bf16(
|
||||
table: *const c_void,
|
||||
token_ids: *const c_void,
|
||||
out: *mut c_void,
|
||||
num_tokens: i32,
|
||||
hidden_size: i32,
|
||||
vocab_size: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
/// Embedding lookup: table[token_ids[i]] for each i.
|
||||
@@ -18,8 +32,14 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
|
||||
let hidden_size = table.shape()[1];
|
||||
let num_tokens = token_ids.len();
|
||||
let vocab_size = table.shape()[0];
|
||||
assert!(num_tokens <= i32::MAX as usize, "too many tokens for i32 kernel param");
|
||||
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
|
||||
assert!(
|
||||
num_tokens <= i32::MAX as usize,
|
||||
"too many tokens for i32 kernel param"
|
||||
);
|
||||
assert!(
|
||||
hidden_size <= i32::MAX as usize,
|
||||
"hidden_size too large for i32 kernel param"
|
||||
);
|
||||
|
||||
// Upload token_ids to GPU
|
||||
let ids_bytes = unsafe {
|
||||
@@ -28,11 +48,15 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
|
||||
num_tokens * std::mem::size_of::<u32>(),
|
||||
)
|
||||
};
|
||||
let mut ids_gpu = xserv_cuda::allocator::cached_alloc(ids_bytes.len()).expect("alloc token_ids");
|
||||
let mut ids_gpu =
|
||||
xserv_cuda::allocator::cached_alloc(ids_bytes.len()).expect("alloc token_ids");
|
||||
ids_gpu.copy_from_host(ids_bytes).unwrap();
|
||||
|
||||
for &tid in token_ids {
|
||||
assert!((tid as usize) < vocab_size, "token_id {tid} out of bounds (vocab_size={vocab_size})");
|
||||
assert!(
|
||||
(tid as usize) < vocab_size,
|
||||
"token_id {tid} out of bounds (vocab_size={vocab_size})"
|
||||
);
|
||||
}
|
||||
|
||||
embedding_device_ids(table, ids_gpu.as_ptr() as *const c_void, num_tokens)
|
||||
@@ -53,14 +77,22 @@ pub fn embedding_device_ids(table: &Tensor, ids_gpu: *const c_void, num_tokens:
|
||||
unsafe {
|
||||
match table.dtype() {
|
||||
DType::F32 => launch_embedding_f32(
|
||||
table.data_ptr() as _, ids_gpu,
|
||||
table.data_ptr() as _,
|
||||
ids_gpu,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden_size as i32, vocab_size as i32, xserv_cuda::current_stream_raw(),
|
||||
num_tokens as i32,
|
||||
hidden_size as i32,
|
||||
vocab_size as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_embedding_bf16(
|
||||
table.data_ptr() as _, ids_gpu,
|
||||
table.data_ptr() as _,
|
||||
ids_gpu,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden_size as i32, vocab_size as i32, xserv_cuda::current_stream_raw(),
|
||||
num_tokens as i32,
|
||||
hidden_size as i32,
|
||||
vocab_size as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for embedding"),
|
||||
}
|
||||
|
||||
@@ -1,14 +1,22 @@
|
||||
use std::cell::RefCell;
|
||||
use std::ffi::c_void;
|
||||
use xserv_cuda::error::{self, Result};
|
||||
use xserv_cuda::GpuBuffer;
|
||||
use xserv_cuda::error::{self, Result};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
const CUBLAS_WORKSPACE_BYTES: usize = 32 * 1024 * 1024;
|
||||
|
||||
// GEMV: single-kernel, no FP32 temp buffer needed
|
||||
unsafe extern "C" {
|
||||
fn launch_gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void, y_fp32_buf: *mut c_void, k: i32, n: i32, stream: *mut c_void);
|
||||
fn launch_gemv_bf16(
|
||||
x: *const c_void,
|
||||
w: *const c_void,
|
||||
y_bf16: *mut c_void,
|
||||
y_fp32_buf: *mut c_void,
|
||||
k: i32,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
@@ -20,10 +28,42 @@ pub enum GemmBackend {
|
||||
|
||||
// --- FFI: custom CUDA kernels ---
|
||||
unsafe extern "C" {
|
||||
fn launch_gemm_naive_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_naive_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_tiled_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_tiled_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_naive_f32(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
c: *mut c_void,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_gemm_naive_bf16(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
c: *mut c_void,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_gemm_tiled_f32(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
c: *mut c_void,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_gemm_tiled_bf16(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
c: *mut c_void,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
// --- FFI: cuBLAS ---
|
||||
@@ -46,25 +86,46 @@ unsafe extern "C" {
|
||||
fn cublasSetWorkspace_v2(handle: CublasHandle, workspace: *mut c_void, size: usize) -> i32;
|
||||
fn cublasGemmEx(
|
||||
handle: CublasHandle,
|
||||
transa: i32, transb: i32,
|
||||
m: i32, n: i32, k: i32,
|
||||
transa: i32,
|
||||
transb: i32,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
alpha: *const c_void,
|
||||
a: *const c_void, a_type: i32, lda: i32,
|
||||
b: *const c_void, b_type: i32, ldb: i32,
|
||||
a: *const c_void,
|
||||
a_type: i32,
|
||||
lda: i32,
|
||||
b: *const c_void,
|
||||
b_type: i32,
|
||||
ldb: i32,
|
||||
beta: *const c_void,
|
||||
c: *mut c_void, c_type: i32, ldc: i32,
|
||||
c: *mut c_void,
|
||||
c_type: i32,
|
||||
ldc: i32,
|
||||
compute_type: i32,
|
||||
algo: i32,
|
||||
) -> i32;
|
||||
fn cublasGemmStridedBatchedEx(
|
||||
handle: CublasHandle,
|
||||
transa: i32, transb: i32,
|
||||
m: i32, n: i32, k: i32,
|
||||
transa: i32,
|
||||
transb: i32,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
alpha: *const c_void,
|
||||
a: *const c_void, a_type: i32, lda: i32, stride_a: i64,
|
||||
b: *const c_void, b_type: i32, ldb: i32, stride_b: i64,
|
||||
a: *const c_void,
|
||||
a_type: i32,
|
||||
lda: i32,
|
||||
stride_a: i64,
|
||||
b: *const c_void,
|
||||
b_type: i32,
|
||||
ldb: i32,
|
||||
stride_b: i64,
|
||||
beta: *const c_void,
|
||||
c: *mut c_void, c_type: i32, ldc: i32, stride_c: i64,
|
||||
c: *mut c_void,
|
||||
c_type: i32,
|
||||
ldc: i32,
|
||||
stride_c: i64,
|
||||
batch_count: i32,
|
||||
compute_type: i32,
|
||||
algo: i32,
|
||||
@@ -89,9 +150,16 @@ impl CublasContext {
|
||||
// set, so we keep the GpuBuffer in this struct.
|
||||
let mut workspace = GpuBuffer::alloc(CUBLAS_WORKSPACE_BYTES)?;
|
||||
error::check(unsafe {
|
||||
cublasSetWorkspace_v2(handle, workspace.as_mut_ptr() as *mut c_void, CUBLAS_WORKSPACE_BYTES)
|
||||
cublasSetWorkspace_v2(
|
||||
handle,
|
||||
workspace.as_mut_ptr() as *mut c_void,
|
||||
CUBLAS_WORKSPACE_BYTES,
|
||||
)
|
||||
})?;
|
||||
Ok(Self { handle, _workspace: Some(workspace) })
|
||||
Ok(Self {
|
||||
handle,
|
||||
_workspace: Some(workspace),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,9 +191,7 @@ where
|
||||
|
||||
/// Get the thread-local cuBLAS handle for use with dispatch module.
|
||||
pub fn cublas_handle() -> CublasHandle {
|
||||
CUBLAS_CTX.with(|cell| {
|
||||
cell.borrow().handle
|
||||
})
|
||||
CUBLAS_CTX.with(|cell| cell.borrow().handle)
|
||||
}
|
||||
|
||||
/// Matrix multiplication: C = A @ B
|
||||
@@ -136,8 +202,14 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||||
assert_eq!(b.ndim(), 2);
|
||||
assert_eq!(a.shape()[1], b.shape()[0], "inner dimension mismatch");
|
||||
assert_eq!(a.dtype(), b.dtype(), "dtype mismatch");
|
||||
assert!(a.is_contiguous() && b.is_contiguous(), "matmul requires contiguous tensors");
|
||||
assert!(matches!(a.device(), Device::Cuda(_)), "matmul requires GPU tensors");
|
||||
assert!(
|
||||
a.is_contiguous() && b.is_contiguous(),
|
||||
"matmul requires contiguous tensors"
|
||||
);
|
||||
assert!(
|
||||
matches!(a.device(), Device::Cuda(_)),
|
||||
"matmul requires GPU tensors"
|
||||
);
|
||||
|
||||
let m = a.shape()[0];
|
||||
let k = a.shape()[1];
|
||||
@@ -154,32 +226,63 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||||
let null_stream = xserv_cuda::current_stream_raw();
|
||||
|
||||
match backend {
|
||||
GemmBackend::Naive => {
|
||||
unsafe {
|
||||
GemmBackend::Naive => unsafe {
|
||||
match dtype {
|
||||
DType::F32 => launch_gemm_naive_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
DType::BF16 => launch_gemm_naive_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
DType::F32 => launch_gemm_naive_f32(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
m as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
null_stream,
|
||||
),
|
||||
DType::BF16 => launch_gemm_naive_bf16(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
m as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
null_stream,
|
||||
),
|
||||
_ => panic!("unsupported dtype for naive GEMM"),
|
||||
}
|
||||
}
|
||||
}
|
||||
GemmBackend::Tiled => {
|
||||
unsafe {
|
||||
},
|
||||
GemmBackend::Tiled => unsafe {
|
||||
match dtype {
|
||||
DType::F32 => launch_gemm_tiled_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
DType::BF16 => launch_gemm_tiled_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
DType::F32 => launch_gemm_tiled_f32(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
m as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
null_stream,
|
||||
),
|
||||
DType::BF16 => launch_gemm_tiled_bf16(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
m as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
null_stream,
|
||||
),
|
||||
_ => panic!("unsupported dtype for tiled GEMM"),
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
GemmBackend::CuBlas => {
|
||||
if m == 1 && dtype == DType::BF16 && n >= 256 {
|
||||
let mut fp32_buf = xserv_cuda::allocator::cached_alloc(n * 4).unwrap();
|
||||
unsafe {
|
||||
launch_gemv_bf16(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
fp32_buf.as_mut_ptr() as *mut c_void,
|
||||
k as i32, n as i32,
|
||||
k as i32,
|
||||
n as i32,
|
||||
null_stream,
|
||||
);
|
||||
}
|
||||
@@ -197,16 +300,26 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||||
cublasSetStream_v2(handle, null_stream);
|
||||
error::check(cublasGemmEx(
|
||||
handle,
|
||||
CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
n as i32, m as i32, k as i32,
|
||||
CUBLAS_OP_N,
|
||||
CUBLAS_OP_N,
|
||||
n as i32,
|
||||
m as i32,
|
||||
k as i32,
|
||||
&alpha as *const f32 as *const c_void,
|
||||
b_ptr, b_type, n as i32,
|
||||
a_ptr, a_type, k as i32,
|
||||
b_ptr,
|
||||
b_type,
|
||||
n as i32,
|
||||
a_ptr,
|
||||
a_type,
|
||||
k as i32,
|
||||
&beta as *const f32 as *const c_void,
|
||||
c_ptr, c_type, n as i32,
|
||||
c_ptr,
|
||||
c_type,
|
||||
n as i32,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
-1,
|
||||
)).expect("cuBLAS GEMM failed");
|
||||
))
|
||||
.expect("cuBLAS GEMM failed");
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -264,17 +377,30 @@ pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
// Row-major trick: C = A @ B ⟺ C^T = B^T @ A^T (col-major)
|
||||
error::check(cublasGemmStridedBatchedEx(
|
||||
handle,
|
||||
CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
n as i32, m as i32, k as i32,
|
||||
CUBLAS_OP_N,
|
||||
CUBLAS_OP_N,
|
||||
n as i32,
|
||||
m as i32,
|
||||
k as i32,
|
||||
&alpha as *const f32 as *const c_void,
|
||||
b.data_ptr() as _, b_type, n as i32, stride_b,
|
||||
a.data_ptr() as _, a_type, k as i32, stride_a,
|
||||
b.data_ptr() as _,
|
||||
b_type,
|
||||
n as i32,
|
||||
stride_b,
|
||||
a.data_ptr() as _,
|
||||
a_type,
|
||||
k as i32,
|
||||
stride_a,
|
||||
&beta as *const f32 as *const c_void,
|
||||
c.data_ptr() as *mut c_void, c_type, n as i32, stride_c,
|
||||
c.data_ptr() as *mut c_void,
|
||||
c_type,
|
||||
n as i32,
|
||||
stride_c,
|
||||
batch as i32,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
-1,
|
||||
)).expect("cuBLAS batched GEMM failed");
|
||||
))
|
||||
.expect("cuBLAS batched GEMM failed");
|
||||
});
|
||||
c
|
||||
}
|
||||
|
||||
@@ -2,10 +2,26 @@ use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_layernorm_f32(x: *const c_void, gamma: *const c_void, beta: *const c_void,
|
||||
out: *mut c_void, rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_layernorm_bf16(x: *const c_void, gamma: *const c_void, beta: *const c_void,
|
||||
out: *mut c_void, rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_layernorm_f32(
|
||||
x: *const c_void,
|
||||
gamma: *const c_void,
|
||||
beta: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
hidden_size: i32,
|
||||
eps: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_layernorm_bf16(
|
||||
x: *const c_void,
|
||||
gamma: *const c_void,
|
||||
beta: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
hidden_size: i32,
|
||||
eps: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor {
|
||||
@@ -17,21 +33,37 @@ pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor
|
||||
assert_eq!(beta.shape(), &[hidden_size]);
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
|
||||
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
|
||||
assert!(
|
||||
rows <= i32::MAX as usize,
|
||||
"too many rows for i32 kernel param"
|
||||
);
|
||||
assert!(
|
||||
hidden_size <= i32::MAX as usize,
|
||||
"hidden_size too large for i32 kernel param"
|
||||
);
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_layernorm_f32(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
|
||||
x.data_ptr() as _,
|
||||
gamma.data_ptr() as _,
|
||||
beta.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
|
||||
rows as i32,
|
||||
hidden_size as i32,
|
||||
eps,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_layernorm_bf16(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
|
||||
x.data_ptr() as _,
|
||||
gamma.data_ptr() as _,
|
||||
beta.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
|
||||
rows as i32,
|
||||
hidden_size as i32,
|
||||
eps,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for layernorm"),
|
||||
}
|
||||
|
||||
@@ -14,14 +14,20 @@ pub mod transpose;
|
||||
|
||||
pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu_mul};
|
||||
pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
|
||||
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
|
||||
pub use attention::{attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention, paged_decode_attention_sinks, reshape_and_cache_bf16, reshape_and_cache_batched_bf16};
|
||||
pub use attention::{
|
||||
attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention,
|
||||
paged_decode_attention_sinks, reshape_and_cache_batched_bf16, reshape_and_cache_bf16,
|
||||
};
|
||||
pub use embedding::{embedding, embedding_device_ids};
|
||||
pub use gemm::{batched_matmul, matmul, GemmBackend};
|
||||
pub use gemm::{GemmBackend, batched_matmul, matmul};
|
||||
pub use layernorm::layernorm;
|
||||
pub use rmsnorm::{add_rmsnorm, rmsnorm};
|
||||
pub use rope::{rope_inplace, rope_inplace_device_pos, RopeCache};
|
||||
pub use rope::{RopeCache, rope_inplace, rope_inplace_device_pos};
|
||||
pub use softmax::softmax;
|
||||
pub use transpose::{
|
||||
merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu,
|
||||
transpose_for_rope_gpu, transpose_from_rope_gpu,
|
||||
};
|
||||
|
||||
/// Register GPU kernels with the tensor crate. Call once at startup.
|
||||
pub fn init() {
|
||||
|
||||
@@ -1,66 +1,113 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Tensor};
|
||||
|
||||
use crate::gemm::{cublas_handle, CublasHandle};
|
||||
use crate::gemm::{CublasHandle, cublas_handle};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_moe_topk_softmax_bf16(
|
||||
router_logits: *const c_void,
|
||||
topk_ids: *mut c_void, topk_weights: *mut c_void,
|
||||
num_tokens: i32, num_experts: i32, top_k: i32,
|
||||
topk_ids: *mut c_void,
|
||||
topk_weights: *mut c_void,
|
||||
num_tokens: i32,
|
||||
num_experts: i32,
|
||||
top_k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_moe_replicate_bf16(
|
||||
x: *const c_void, x_rep: *mut c_void,
|
||||
num_tokens: i32, hidden: i32, local_experts: i32,
|
||||
x: *const c_void,
|
||||
x_rep: *mut c_void,
|
||||
num_tokens: i32,
|
||||
hidden: i32,
|
||||
local_experts: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_moe_bias_add_3d_bf16(
|
||||
x: *mut c_void, bias: *const c_void,
|
||||
batch: i32, num_tokens: i32, dim: i32,
|
||||
x: *mut c_void,
|
||||
bias: *const c_void,
|
||||
batch: i32,
|
||||
num_tokens: i32,
|
||||
dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_moe_weighted_sum_bf16(
|
||||
expert_out: *const c_void,
|
||||
topk_ids: *const c_void, topk_weights: *const c_void,
|
||||
topk_ids: *const c_void,
|
||||
topk_weights: *const c_void,
|
||||
out: *mut c_void,
|
||||
num_tokens: i32, hidden: i32, top_k: i32,
|
||||
expert_start: i32, local_experts: i32,
|
||||
num_tokens: i32,
|
||||
hidden: i32,
|
||||
top_k: i32,
|
||||
expert_start: i32,
|
||||
local_experts: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
|
||||
fn launch_moe_sparse_gemv_fp8_bf16(
|
||||
x: *const c_void, w: *const c_void, w_scales: *const c_void,
|
||||
bias: *const c_void, topk_ids: *const c_void, y: *mut c_void,
|
||||
num_tokens: i32, n: i32, k: i32, top_k: i32,
|
||||
expert_start: i32, local_experts: i32, x_per_slot: i32,
|
||||
x: *const c_void,
|
||||
w: *const c_void,
|
||||
w_scales: *const c_void,
|
||||
bias: *const c_void,
|
||||
topk_ids: *const c_void,
|
||||
y: *mut c_void,
|
||||
num_tokens: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
top_k: i32,
|
||||
expert_start: i32,
|
||||
local_experts: i32,
|
||||
x_per_slot: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_moe_sparse_gemv_mxfp4_bf16(
|
||||
x: *const c_void, w_packed: *const c_void, w_scales: *const c_void,
|
||||
bias: *const c_void, topk_ids: *const c_void, y: *mut c_void,
|
||||
num_tokens: i32, n: i32, k: i32, top_k: i32,
|
||||
expert_start: i32, local_experts: i32, x_per_slot: i32,
|
||||
x: *const c_void,
|
||||
w_packed: *const c_void,
|
||||
w_scales: *const c_void,
|
||||
bias: *const c_void,
|
||||
topk_ids: *const c_void,
|
||||
y: *mut c_void,
|
||||
num_tokens: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
top_k: i32,
|
||||
expert_start: i32,
|
||||
local_experts: i32,
|
||||
x_per_slot: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_moe_weighted_sum_sparse_bf16(
|
||||
down: *const c_void,
|
||||
topk_ids: *const c_void, topk_weights: *const c_void,
|
||||
topk_ids: *const c_void,
|
||||
topk_weights: *const c_void,
|
||||
out: *mut c_void,
|
||||
num_tokens: i32, hidden: i32, top_k: i32,
|
||||
expert_start: i32, local_experts: i32,
|
||||
num_tokens: i32,
|
||||
hidden: i32,
|
||||
top_k: i32,
|
||||
expert_start: i32,
|
||||
local_experts: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
|
||||
fn cublasGemmStridedBatchedEx(
|
||||
handle: CublasHandle,
|
||||
transa: i32, transb: i32,
|
||||
m: i32, n: i32, k: i32,
|
||||
transa: i32,
|
||||
transb: i32,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
alpha: *const c_void,
|
||||
a: *const c_void, a_type: i32, lda: i32, stride_a: i64,
|
||||
b: *const c_void, b_type: i32, ldb: i32, stride_b: i64,
|
||||
a: *const c_void,
|
||||
a_type: i32,
|
||||
lda: i32,
|
||||
stride_a: i64,
|
||||
b: *const c_void,
|
||||
b_type: i32,
|
||||
ldb: i32,
|
||||
stride_b: i64,
|
||||
beta: *const c_void,
|
||||
c: *mut c_void, c_type: i32, ldc: i32, stride_c: i64,
|
||||
c: *mut c_void,
|
||||
c_type: i32,
|
||||
ldc: i32,
|
||||
stride_c: i64,
|
||||
batch_count: i32,
|
||||
compute_type: i32,
|
||||
algo: i32,
|
||||
@@ -99,7 +146,9 @@ pub fn moe_topk_softmax(
|
||||
router_logits.data_ptr() as *const c_void,
|
||||
topk_ids.data_ptr() as *mut c_void,
|
||||
topk_weights.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, num_experts as i32, top_k as i32,
|
||||
num_tokens as i32,
|
||||
num_experts as i32,
|
||||
top_k as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
@@ -114,13 +163,19 @@ pub fn moe_replicate(x: &Tensor, local_experts: usize) -> Tensor {
|
||||
assert!(x.is_contiguous());
|
||||
let num_tokens = x.shape()[0];
|
||||
let hidden = x.shape()[1];
|
||||
let out = Tensor::empty(&[local_experts, num_tokens, hidden], DType::BF16, x.device());
|
||||
let out = Tensor::empty(
|
||||
&[local_experts, num_tokens, hidden],
|
||||
DType::BF16,
|
||||
x.device(),
|
||||
);
|
||||
|
||||
unsafe {
|
||||
launch_moe_replicate_bf16(
|
||||
x.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden as i32, local_experts as i32,
|
||||
num_tokens as i32,
|
||||
hidden as i32,
|
||||
local_experts as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
@@ -143,7 +198,9 @@ pub fn moe_bias_add_3d(x: &Tensor, bias: &Tensor) {
|
||||
launch_moe_bias_add_3d_bf16(
|
||||
x.data_ptr() as *mut c_void,
|
||||
bias.data_ptr() as *const c_void,
|
||||
batch as i32, num_tokens as i32, dim as i32,
|
||||
batch as i32,
|
||||
num_tokens as i32,
|
||||
dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
@@ -175,8 +232,11 @@ pub fn moe_weighted_sum(
|
||||
topk_ids.data_ptr() as *const c_void,
|
||||
topk_weights.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden as i32, top_k as i32,
|
||||
expert_start as i32, local_experts as i32,
|
||||
num_tokens as i32,
|
||||
hidden as i32,
|
||||
top_k as i32,
|
||||
expert_start as i32,
|
||||
local_experts as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
@@ -198,9 +258,16 @@ pub fn moe_weighted_sum(
|
||||
/// consumer must skip them (see moe_weighted_sum_sparse).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn moe_sparse_gemv_fp8(
|
||||
x: &Tensor, w_fp8_t: &Tensor, w_scales: &Tensor, bias: &Tensor,
|
||||
topk_ids: &Tensor, num_tokens: usize, top_k: usize,
|
||||
expert_start: usize, local_experts: usize, x_per_slot: bool,
|
||||
x: &Tensor,
|
||||
w_fp8_t: &Tensor,
|
||||
w_scales: &Tensor,
|
||||
bias: &Tensor,
|
||||
topk_ids: &Tensor,
|
||||
num_tokens: usize,
|
||||
top_k: usize,
|
||||
expert_start: usize,
|
||||
local_experts: usize,
|
||||
x_per_slot: bool,
|
||||
) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous());
|
||||
@@ -211,7 +278,14 @@ pub fn moe_sparse_gemv_fp8(
|
||||
// silently skip a K%16 tail.
|
||||
assert_eq!(k % 16, 0, "sparse FP8 GEMV requires K % 16 == 0, got {k}");
|
||||
assert_eq!(x.shape()[x.ndim() - 1], k);
|
||||
assert_eq!(x.shape()[0], if x_per_slot { num_tokens * top_k } else { num_tokens });
|
||||
assert_eq!(
|
||||
x.shape()[0],
|
||||
if x_per_slot {
|
||||
num_tokens * top_k
|
||||
} else {
|
||||
num_tokens
|
||||
}
|
||||
);
|
||||
|
||||
let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device());
|
||||
unsafe {
|
||||
@@ -222,8 +296,13 @@ pub fn moe_sparse_gemv_fp8(
|
||||
bias.data_ptr() as *const c_void,
|
||||
topk_ids.data_ptr() as *const c_void,
|
||||
y.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, n as i32, k as i32, top_k as i32,
|
||||
expert_start as i32, local_experts as i32, x_per_slot as i32,
|
||||
num_tokens as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
top_k as i32,
|
||||
expert_start as i32,
|
||||
local_experts as i32,
|
||||
x_per_slot as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
@@ -234,16 +313,32 @@ pub fn moe_sparse_gemv_fp8(
|
||||
/// with packed 4-bit weights [E, N, K/2] + UE8M0 block scales [E, N, K/32].
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn moe_sparse_gemv_mxfp4(
|
||||
x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, bias: &Tensor,
|
||||
topk_ids: &Tensor, num_tokens: usize, top_k: usize, n: usize, k: usize,
|
||||
expert_start: usize, local_experts: usize, x_per_slot: bool,
|
||||
x: &Tensor,
|
||||
w_packed: &Tensor,
|
||||
w_scales: &Tensor,
|
||||
bias: &Tensor,
|
||||
topk_ids: &Tensor,
|
||||
num_tokens: usize,
|
||||
top_k: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
expert_start: usize,
|
||||
local_experts: usize,
|
||||
x_per_slot: bool,
|
||||
) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous());
|
||||
// 32-element MXFP4 blocks, read as uint4 (32 nibbles) per lane.
|
||||
assert_eq!(k % 32, 0, "sparse MXFP4 GEMV requires K % 32 == 0, got {k}");
|
||||
assert_eq!(x.shape()[x.ndim() - 1], k);
|
||||
assert_eq!(x.shape()[0], if x_per_slot { num_tokens * top_k } else { num_tokens });
|
||||
assert_eq!(
|
||||
x.shape()[0],
|
||||
if x_per_slot {
|
||||
num_tokens * top_k
|
||||
} else {
|
||||
num_tokens
|
||||
}
|
||||
);
|
||||
|
||||
let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device());
|
||||
unsafe {
|
||||
@@ -254,8 +349,13 @@ pub fn moe_sparse_gemv_mxfp4(
|
||||
bias.data_ptr() as *const c_void,
|
||||
topk_ids.data_ptr() as *const c_void,
|
||||
y.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, n as i32, k as i32, top_k as i32,
|
||||
expert_start as i32, local_experts as i32, x_per_slot as i32,
|
||||
num_tokens as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
top_k as i32,
|
||||
expert_start as i32,
|
||||
local_experts as i32,
|
||||
x_per_slot as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
@@ -286,8 +386,11 @@ pub fn moe_weighted_sum_sparse(
|
||||
topk_ids.data_ptr() as *const c_void,
|
||||
topk_weights.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden as i32, top_k as i32,
|
||||
expert_start as i32, local_experts as i32,
|
||||
num_tokens as i32,
|
||||
hidden as i32,
|
||||
top_k as i32,
|
||||
expert_start as i32,
|
||||
local_experts as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
@@ -341,13 +444,25 @@ pub fn batched_gemm_strided(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
cublasSetStream_v2(handle, xserv_cuda::current_stream_raw());
|
||||
let status = cublasGemmStridedBatchedEx(
|
||||
handle,
|
||||
0, 0, // CUBLAS_OP_N, CUBLAS_OP_N
|
||||
n as i32, m as i32, k as i32,
|
||||
0,
|
||||
0, // CUBLAS_OP_N, CUBLAS_OP_N
|
||||
n as i32,
|
||||
m as i32,
|
||||
k as i32,
|
||||
&alpha as *const f32 as *const c_void,
|
||||
b.data_ptr() as *const c_void, CUDA_R_16BF, n as i32, stride_b,
|
||||
a.data_ptr() as *const c_void, CUDA_R_16BF, k as i32, stride_a,
|
||||
b.data_ptr() as *const c_void,
|
||||
CUDA_R_16BF,
|
||||
n as i32,
|
||||
stride_b,
|
||||
a.data_ptr() as *const c_void,
|
||||
CUDA_R_16BF,
|
||||
k as i32,
|
||||
stride_a,
|
||||
&beta as *const f32 as *const c_void,
|
||||
c.data_ptr() as *mut c_void, CUDA_R_16BF, n as i32, stride_c,
|
||||
c.data_ptr() as *mut c_void,
|
||||
CUDA_R_16BF,
|
||||
n as i32,
|
||||
stride_c,
|
||||
batch as i32,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
CUBLAS_GEMM_DEFAULT,
|
||||
|
||||
@@ -13,30 +13,46 @@ unsafe extern "C" {
|
||||
src: *const c_void,
|
||||
scales: *const c_void,
|
||||
dst: *mut c_void,
|
||||
num_experts: i32, rows: i32, cols: i32,
|
||||
num_experts: i32,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_quantize_bf16_to_fp8e4m3_rowwise(
|
||||
src: *const c_void,
|
||||
dst: *mut c_void,
|
||||
scales: *mut c_void,
|
||||
num_rows: i32, cols: i32,
|
||||
num_rows: i32,
|
||||
cols: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_rowwise_scale_moe_bf16(
|
||||
data: *mut c_void,
|
||||
a_scales: *const c_void,
|
||||
b_scales: *const c_void,
|
||||
num_rows: i32, cols: i32, tokens: i32,
|
||||
num_rows: i32,
|
||||
cols: i32,
|
||||
tokens: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_batched_gemv_mxfp4_bf16(
|
||||
x: *const c_void, w_packed: *const c_void, w_scales: *const c_void, y: *mut c_void,
|
||||
e: i32, n: i32, k: i32, stream: *mut c_void,
|
||||
x: *const c_void,
|
||||
w_packed: *const c_void,
|
||||
w_scales: *const c_void,
|
||||
y: *mut c_void,
|
||||
e: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_dequant_mxfp4_to_bf16_t(
|
||||
w_packed: *const c_void, w_scales: *const c_void, out: *mut c_void,
|
||||
e: i32, n: i32, k: i32, stream: *mut c_void,
|
||||
w_packed: *const c_void,
|
||||
w_scales: *const c_void,
|
||||
out: *mut c_void,
|
||||
e: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -66,34 +82,68 @@ struct CublasLtMatmulHeuristicResult {
|
||||
unsafe extern "C" {
|
||||
fn cublasLtCreate(handle: *mut CublasLtHandle) -> i32;
|
||||
fn cublasLtDestroy(handle: CublasLtHandle) -> i32;
|
||||
fn cublasLtMatmulDescCreate(desc: *mut CublasLtMatmulDesc, compute_type: i32, scale_type: i32) -> i32;
|
||||
fn cublasLtMatmulDescCreate(
|
||||
desc: *mut CublasLtMatmulDesc,
|
||||
compute_type: i32,
|
||||
scale_type: i32,
|
||||
) -> i32;
|
||||
fn cublasLtMatmulDescDestroy(desc: CublasLtMatmulDesc) -> i32;
|
||||
fn cublasLtMatmulDescSetAttribute(desc: CublasLtMatmulDesc, attr: i32, buf: *const c_void, size: usize) -> i32;
|
||||
fn cublasLtMatrixLayoutCreate(layout: *mut CublasLtMatrixLayout, dtype: i32, rows: u64, cols: u64, ld: i64) -> i32;
|
||||
fn cublasLtMatmulDescSetAttribute(
|
||||
desc: CublasLtMatmulDesc,
|
||||
attr: i32,
|
||||
buf: *const c_void,
|
||||
size: usize,
|
||||
) -> i32;
|
||||
fn cublasLtMatrixLayoutCreate(
|
||||
layout: *mut CublasLtMatrixLayout,
|
||||
dtype: i32,
|
||||
rows: u64,
|
||||
cols: u64,
|
||||
ld: i64,
|
||||
) -> i32;
|
||||
fn cublasLtMatrixLayoutDestroy(layout: CublasLtMatrixLayout) -> i32;
|
||||
fn cublasLtMatrixLayoutSetAttribute(layout: CublasLtMatrixLayout, attr: i32, buf: *const c_void, size: usize) -> i32;
|
||||
fn cublasLtMatrixLayoutSetAttribute(
|
||||
layout: CublasLtMatrixLayout,
|
||||
attr: i32,
|
||||
buf: *const c_void,
|
||||
size: usize,
|
||||
) -> i32;
|
||||
fn cublasLtMatmulPreferenceCreate(pref: *mut CublasLtMatmulPreference) -> i32;
|
||||
fn cublasLtMatmulPreferenceDestroy(pref: CublasLtMatmulPreference) -> i32;
|
||||
fn cublasLtMatmulPreferenceSetAttribute(pref: CublasLtMatmulPreference, attr: i32, buf: *const c_void, size: usize) -> i32;
|
||||
fn cublasLtMatmulPreferenceSetAttribute(
|
||||
pref: CublasLtMatmulPreference,
|
||||
attr: i32,
|
||||
buf: *const c_void,
|
||||
size: usize,
|
||||
) -> i32;
|
||||
fn cublasLtMatmulAlgoGetHeuristic(
|
||||
handle: CublasLtHandle, desc: CublasLtMatmulDesc,
|
||||
a_layout: CublasLtMatrixLayout, b_layout: CublasLtMatrixLayout,
|
||||
c_layout: CublasLtMatrixLayout, d_layout: CublasLtMatrixLayout,
|
||||
handle: CublasLtHandle,
|
||||
desc: CublasLtMatmulDesc,
|
||||
a_layout: CublasLtMatrixLayout,
|
||||
b_layout: CublasLtMatrixLayout,
|
||||
c_layout: CublasLtMatrixLayout,
|
||||
d_layout: CublasLtMatrixLayout,
|
||||
pref: CublasLtMatmulPreference,
|
||||
requested: i32,
|
||||
results: *mut CublasLtMatmulHeuristicResult,
|
||||
found: *mut i32,
|
||||
) -> i32;
|
||||
fn cublasLtMatmul(
|
||||
handle: CublasLtHandle, desc: CublasLtMatmulDesc,
|
||||
handle: CublasLtHandle,
|
||||
desc: CublasLtMatmulDesc,
|
||||
alpha: *const c_void,
|
||||
a: *const c_void, a_layout: CublasLtMatrixLayout,
|
||||
b: *const c_void, b_layout: CublasLtMatrixLayout,
|
||||
a: *const c_void,
|
||||
a_layout: CublasLtMatrixLayout,
|
||||
b: *const c_void,
|
||||
b_layout: CublasLtMatrixLayout,
|
||||
beta: *const c_void,
|
||||
c: *const c_void, c_layout: CublasLtMatrixLayout,
|
||||
d: *mut c_void, d_layout: CublasLtMatrixLayout,
|
||||
c: *const c_void,
|
||||
c_layout: CublasLtMatrixLayout,
|
||||
d: *mut c_void,
|
||||
d_layout: CublasLtMatrixLayout,
|
||||
algo: *const CublasLtMatmulAlgo,
|
||||
workspace: *mut c_void, workspace_size: usize,
|
||||
workspace: *mut c_void,
|
||||
workspace_size: usize,
|
||||
stream: *mut c_void,
|
||||
) -> i32;
|
||||
}
|
||||
@@ -153,8 +203,15 @@ impl CublasLtContext {
|
||||
assert_eq!(status, 0, "cublasLtCreate failed: {status}");
|
||||
let workspace = GpuBuffer::alloc(WORKSPACE_BYTES).expect("alloc cublasLt workspace");
|
||||
let mut one_buf = GpuBuffer::alloc(4).expect("alloc cublasLt fp8 scale");
|
||||
one_buf.copy_from_host(&1.0f32.to_le_bytes()).expect("init fp8 scale");
|
||||
Self { handle, workspace, one_buf, plans: HashMap::new() }
|
||||
one_buf
|
||||
.copy_from_host(&1.0f32.to_le_bytes())
|
||||
.expect("init fp8 scale");
|
||||
Self {
|
||||
handle,
|
||||
workspace,
|
||||
one_buf,
|
||||
plans: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the cached strided-batched plan for (m, n, k, batch), building it on
|
||||
@@ -210,10 +267,25 @@ unsafe fn build_fp8_plan(
|
||||
|
||||
// transA=T (required for FP8 on Blackwell)
|
||||
let trans_a: i32 = 1;
|
||||
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a as *const i32 as _, 4);
|
||||
cublasLtMatmulDescSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATMUL_DESC_TRANSA,
|
||||
&trans_a as *const i32 as _,
|
||||
4,
|
||||
);
|
||||
let ptr_sz = std::mem::size_of::<*const c_void>();
|
||||
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz);
|
||||
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz);
|
||||
cublasLtMatmulDescSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
|
||||
&one_ptr as *const _ as _,
|
||||
ptr_sz,
|
||||
);
|
||||
cublasLtMatmulDescSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
|
||||
&one_ptr as *const _ as _,
|
||||
ptr_sz,
|
||||
);
|
||||
|
||||
// Per-expert strides in ELEMENTS for the strided-batch layout.
|
||||
let stride_a = (n * k) as i64; // weights [N, K]
|
||||
@@ -221,10 +293,18 @@ unsafe fn build_fp8_plan(
|
||||
let stride_c = (m * n) as i64; // output [M, N]
|
||||
let bc = batch as i32;
|
||||
let set_batch = |layout: CublasLtMatrixLayout, stride: i64| {
|
||||
cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&bc as *const i32 as _, 4);
|
||||
cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
|
||||
&stride as *const i64 as _, 8);
|
||||
cublasLtMatrixLayoutSetAttribute(
|
||||
layout,
|
||||
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&bc as *const i32 as _,
|
||||
4,
|
||||
);
|
||||
cublasLtMatrixLayoutSetAttribute(
|
||||
layout,
|
||||
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
|
||||
&stride as *const i64 as _,
|
||||
8,
|
||||
);
|
||||
};
|
||||
|
||||
// "A" layout (weights, transposed): physical (K, N) col-major, ld=K
|
||||
@@ -246,20 +326,39 @@ unsafe fn build_fp8_plan(
|
||||
let mut pref: CublasLtMatmulPreference = std::ptr::null_mut();
|
||||
cublasLtMatmulPreferenceCreate(&mut pref);
|
||||
let ws_bytes = WORKSPACE_BYTES as u64;
|
||||
cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &ws_bytes as *const u64 as _, 8);
|
||||
cublasLtMatmulPreferenceSetAttribute(
|
||||
pref,
|
||||
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&ws_bytes as *const u64 as _,
|
||||
8,
|
||||
);
|
||||
|
||||
let mut heuristic = std::mem::zeroed::<CublasLtMatmulHeuristicResult>();
|
||||
let mut found: i32 = 0;
|
||||
let status = cublasLtMatmulAlgoGetHeuristic(
|
||||
handle, desc, a_layout, b_layout, c_layout, d_layout,
|
||||
pref, 1, &mut heuristic, &mut found,
|
||||
handle,
|
||||
desc,
|
||||
a_layout,
|
||||
b_layout,
|
||||
c_layout,
|
||||
d_layout,
|
||||
pref,
|
||||
1,
|
||||
&mut heuristic,
|
||||
&mut found,
|
||||
);
|
||||
assert!(
|
||||
status == 0 && found > 0,
|
||||
"cublasLtMatmulAlgoGetHeuristic failed for batched FP8 GEMM (m={m}, n={n}, k={k}, batch={batch}): status={status}, found={found}"
|
||||
);
|
||||
assert!(status == 0 && found > 0,
|
||||
"cublasLtMatmulAlgoGetHeuristic failed for batched FP8 GEMM (m={m}, n={n}, k={k}, batch={batch}): status={status}, found={found}");
|
||||
cublasLtMatmulPreferenceDestroy(pref);
|
||||
|
||||
Fp8Plan {
|
||||
desc, a_layout, b_layout, c_layout, d_layout,
|
||||
desc,
|
||||
a_layout,
|
||||
b_layout,
|
||||
c_layout,
|
||||
d_layout,
|
||||
algo: heuristic.algo,
|
||||
workspace_size: heuristic.workspace_size,
|
||||
}
|
||||
@@ -299,7 +398,9 @@ pub fn dequant_fp8_to_bf16(src: &Tensor, scales: &Tensor) -> Tensor {
|
||||
src.data_ptr() as *const c_void,
|
||||
scales.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_experts as i32, rows as i32, cols as i32,
|
||||
num_experts as i32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
@@ -329,7 +430,8 @@ pub fn quantize_bf16_to_fp8_rowwise(src: &Tensor) -> (Tensor, Tensor) {
|
||||
src.data_ptr() as *const c_void,
|
||||
fp8_out.data_ptr() as *mut c_void,
|
||||
scales.data_ptr() as *mut c_void,
|
||||
num_rows as i32, cols as i32,
|
||||
num_rows as i32,
|
||||
cols as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
@@ -392,7 +494,8 @@ pub fn batched_gemm_fp8(
|
||||
|
||||
unsafe {
|
||||
let status = cublasLtMatmul(
|
||||
handle, plan.desc,
|
||||
handle,
|
||||
plan.desc,
|
||||
&alpha as *const f32 as _,
|
||||
b_fp8_t.data_ptr() as *const c_void, // cuBLASLt "A" = weights
|
||||
plan.a_layout,
|
||||
@@ -408,7 +511,10 @@ pub fn batched_gemm_fp8(
|
||||
plan.workspace_size,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
assert_eq!(status, 0, "batched cublasLtMatmul FP8 failed: status={status}");
|
||||
assert_eq!(
|
||||
status, 0,
|
||||
"batched cublasLtMatmul FP8 failed: status={status}"
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -423,7 +529,9 @@ pub fn batched_gemm_fp8(
|
||||
c.data_ptr() as *mut c_void,
|
||||
a_scales.data_ptr() as *const c_void,
|
||||
b_scales.data_ptr() as *const c_void,
|
||||
total_rows, n as i32, m as i32,
|
||||
total_rows,
|
||||
n as i32,
|
||||
m as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
@@ -442,7 +550,13 @@ pub fn batched_gemm_fp8(
|
||||
/// w_scales: [E, N, K/32] byte tensor — UE8M0 scale per 32-element block
|
||||
///
|
||||
/// Returns: [E, N] BF16, where y[e,n] = sum_k x[e,k] * dequant(W[e,n,k]).
|
||||
pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: usize, k: usize) -> Tensor {
|
||||
pub fn batched_gemv_mxfp4(
|
||||
x: &Tensor,
|
||||
w_packed: &Tensor,
|
||||
w_scales: &Tensor,
|
||||
n: usize,
|
||||
k: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous());
|
||||
let e = x.shape()[0];
|
||||
@@ -455,7 +569,9 @@ pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: u
|
||||
w_packed.data_ptr() as *const c_void,
|
||||
w_scales.data_ptr() as *const c_void,
|
||||
y.data_ptr() as *mut c_void,
|
||||
e as i32, n as i32, k as i32,
|
||||
e as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
@@ -464,14 +580,22 @@ pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: u
|
||||
|
||||
/// Dequantize MXFP4 weights [E, N, K] → BF16 [E, K, N] for the prefill GEMM path
|
||||
/// (the BF16 batched GEMM expects weights as [E, K, N]).
|
||||
pub fn dequant_mxfp4_to_bf16_t(w_packed: &Tensor, w_scales: &Tensor, e: usize, n: usize, k: usize) -> Tensor {
|
||||
pub fn dequant_mxfp4_to_bf16_t(
|
||||
w_packed: &Tensor,
|
||||
w_scales: &Tensor,
|
||||
e: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
) -> Tensor {
|
||||
let out = Tensor::empty(&[e, k, n], DType::BF16, w_packed.device());
|
||||
unsafe {
|
||||
launch_dequant_mxfp4_to_bf16_t(
|
||||
w_packed.data_ptr() as *const c_void,
|
||||
w_scales.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
e as i32, n as i32, k as i32,
|
||||
e as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -2,13 +2,35 @@ use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_rmsnorm_f32(x: *const c_void, gamma: *const c_void, out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void,
|
||||
normed_out: *mut c_void, sum_out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_rmsnorm_f32(
|
||||
x: *const c_void,
|
||||
gamma: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
hidden_size: i32,
|
||||
eps: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_rmsnorm_bf16(
|
||||
x: *const c_void,
|
||||
gamma: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
hidden_size: i32,
|
||||
eps: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_add_rmsnorm_bf16(
|
||||
x: *const c_void,
|
||||
residual: *const c_void,
|
||||
gamma: *const c_void,
|
||||
normed_out: *mut c_void,
|
||||
sum_out: *mut c_void,
|
||||
rows: i32,
|
||||
hidden_size: i32,
|
||||
eps: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
|
||||
@@ -20,19 +42,35 @@ pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
|
||||
assert_eq!(x.dtype(), gamma.dtype());
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
|
||||
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
|
||||
assert!(
|
||||
rows <= i32::MAX as usize,
|
||||
"too many rows for i32 kernel param"
|
||||
);
|
||||
assert!(
|
||||
hidden_size <= i32::MAX as usize,
|
||||
"hidden_size too large for i32 kernel param"
|
||||
);
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_rmsnorm_f32(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
|
||||
x.data_ptr() as _,
|
||||
gamma.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32,
|
||||
hidden_size as i32,
|
||||
eps,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_rmsnorm_bf16(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
|
||||
x.data_ptr() as _,
|
||||
gamma.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32,
|
||||
hidden_size as i32,
|
||||
eps,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for rmsnorm"),
|
||||
}
|
||||
@@ -56,8 +94,14 @@ pub fn add_rmsnorm(x: &Tensor, residual: &Tensor, gamma: &Tensor, eps: f32) -> (
|
||||
assert_eq!(gamma.shape(), &[hidden_size]);
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
|
||||
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
|
||||
assert!(
|
||||
rows <= i32::MAX as usize,
|
||||
"too many rows for i32 kernel param"
|
||||
);
|
||||
assert!(
|
||||
hidden_size <= i32::MAX as usize,
|
||||
"hidden_size too large for i32 kernel param"
|
||||
);
|
||||
let normed_out = Tensor::empty(x.shape(), DType::BF16, x.device());
|
||||
let sum_out = Tensor::empty(x.shape(), DType::BF16, x.device());
|
||||
|
||||
|
||||
@@ -3,15 +3,34 @@ use xserv_cuda::GpuBuffer;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_rope_f32(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
|
||||
positions: *const c_void, num_tokens: i32, num_heads: i32,
|
||||
head_dim: i32, stream: *mut c_void);
|
||||
fn launch_rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
|
||||
positions: *const c_void, num_tokens: i32, num_heads: i32,
|
||||
head_dim: i32, stream: *mut c_void);
|
||||
fn launch_compute_rope_cache(cos_cache: *mut c_void, sin_cache: *mut c_void,
|
||||
max_seq_len: i32, half_dim: i32, theta: f32,
|
||||
stream: *mut c_void);
|
||||
fn launch_rope_f32(
|
||||
x: *mut c_void,
|
||||
cos_cache: *const c_void,
|
||||
sin_cache: *const c_void,
|
||||
positions: *const c_void,
|
||||
num_tokens: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_rope_bf16(
|
||||
x: *mut c_void,
|
||||
cos_cache: *const c_void,
|
||||
sin_cache: *const c_void,
|
||||
positions: *const c_void,
|
||||
num_tokens: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_compute_rope_cache(
|
||||
cos_cache: *mut c_void,
|
||||
sin_cache: *mut c_void,
|
||||
max_seq_len: i32,
|
||||
half_dim: i32,
|
||||
theta: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
pub struct RopeCache {
|
||||
@@ -30,12 +49,21 @@ impl RopeCache {
|
||||
|
||||
unsafe {
|
||||
launch_compute_rope_cache(
|
||||
cos.as_mut_ptr() as _, sin.as_mut_ptr() as _,
|
||||
max_seq_len as i32, half_dim as i32, theta, xserv_cuda::current_stream_raw(),
|
||||
cos.as_mut_ptr() as _,
|
||||
sin.as_mut_ptr() as _,
|
||||
max_seq_len as i32,
|
||||
half_dim as i32,
|
||||
theta,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
Self { cos, sin, max_seq_len, half_dim }
|
||||
Self {
|
||||
cos,
|
||||
sin,
|
||||
max_seq_len,
|
||||
half_dim,
|
||||
}
|
||||
}
|
||||
|
||||
/// YaRN (Yet another RoPE extensioN) RoPE cache. Applies frequency-dependent
|
||||
@@ -101,16 +129,19 @@ impl RopeCache {
|
||||
let nbytes = total * std::mem::size_of::<f32>();
|
||||
let mut cos = GpuBuffer::alloc(nbytes).expect("alloc yarn cos_cache");
|
||||
let mut sin = GpuBuffer::alloc(nbytes).expect("alloc yarn sin_cache");
|
||||
let cos_bytes = unsafe {
|
||||
std::slice::from_raw_parts(cos_host.as_ptr() as *const u8, nbytes)
|
||||
};
|
||||
let sin_bytes = unsafe {
|
||||
std::slice::from_raw_parts(sin_host.as_ptr() as *const u8, nbytes)
|
||||
};
|
||||
let cos_bytes =
|
||||
unsafe { std::slice::from_raw_parts(cos_host.as_ptr() as *const u8, nbytes) };
|
||||
let sin_bytes =
|
||||
unsafe { std::slice::from_raw_parts(sin_host.as_ptr() as *const u8, nbytes) };
|
||||
cos.copy_from_host(cos_bytes).unwrap();
|
||||
sin.copy_from_host(sin_bytes).unwrap();
|
||||
|
||||
Self { cos, sin, max_seq_len, half_dim }
|
||||
Self {
|
||||
cos,
|
||||
sin,
|
||||
max_seq_len,
|
||||
half_dim,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,7 +164,8 @@ pub fn rope_inplace(x: &Tensor, cache: &RopeCache, positions: &[u32]) {
|
||||
num_tokens * std::mem::size_of::<u32>(),
|
||||
)
|
||||
};
|
||||
let mut pos_gpu = xserv_cuda::allocator::cached_alloc(pos_bytes.len()).expect("alloc positions");
|
||||
let mut pos_gpu =
|
||||
xserv_cuda::allocator::cached_alloc(pos_bytes.len()).expect("alloc positions");
|
||||
pos_gpu.copy_from_host(pos_bytes).unwrap();
|
||||
|
||||
rope_inplace_device_pos(x, cache, pos_gpu.as_ptr() as *const c_void);
|
||||
@@ -155,16 +187,22 @@ pub fn rope_inplace_device_pos(x: &Tensor, cache: &RopeCache, pos_gpu: *const c_
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_rope_f32(
|
||||
x.data_ptr() as *mut c_void,
|
||||
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
|
||||
cache.cos.as_ptr() as _,
|
||||
cache.sin.as_ptr() as _,
|
||||
pos_gpu,
|
||||
num_tokens as i32, num_heads as i32, head_dim as i32,
|
||||
num_tokens as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_rope_bf16(
|
||||
x.data_ptr() as *mut c_void,
|
||||
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
|
||||
cache.cos.as_ptr() as _,
|
||||
cache.sin.as_ptr() as _,
|
||||
pos_gpu,
|
||||
num_tokens as i32, num_heads as i32, head_dim as i32,
|
||||
num_tokens as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for rope"),
|
||||
|
||||
@@ -2,8 +2,20 @@ use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_softmax_f32(x: *const c_void, out: *mut c_void, rows: i32, cols: i32, stream: *mut c_void);
|
||||
fn launch_softmax_bf16(x: *const c_void, out: *mut c_void, rows: i32, cols: i32, stream: *mut c_void);
|
||||
fn launch_softmax_f32(
|
||||
x: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_softmax_bf16(
|
||||
x: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
/// Softmax along the last dimension.
|
||||
@@ -14,19 +26,31 @@ pub fn softmax(x: &Tensor) -> Tensor {
|
||||
|
||||
let cols = *x.shape().last().unwrap();
|
||||
let rows = x.numel() / cols;
|
||||
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
|
||||
assert!(cols <= i32::MAX as usize, "cols too large for i32 kernel param");
|
||||
assert!(
|
||||
rows <= i32::MAX as usize,
|
||||
"too many rows for i32 kernel param"
|
||||
);
|
||||
assert!(
|
||||
cols <= i32::MAX as usize,
|
||||
"cols too large for i32 kernel param"
|
||||
);
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_softmax_f32(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, cols as i32, xserv_cuda::current_stream_raw(),
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_softmax_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, cols as i32, xserv_cuda::current_stream_raw(),
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for softmax"),
|
||||
}
|
||||
|
||||
@@ -2,19 +2,79 @@ use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_reshape_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_merge_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_repeat_kv_bf16(inp: *const c_void, out: *mut c_void, kv_heads: i32, n_rep: i32, seq_len: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_strided_copy_bf16(inp: *const c_void, out: *mut c_void, numel: i32, ndim: i32,
|
||||
shape0: i32, shape1: i32, shape2: i32, shape3: i32,
|
||||
in_stride0: i32, in_stride1: i32, in_stride2: i32, in_stride3: i32,
|
||||
in_offset: i32, stream: *mut c_void);
|
||||
fn launch_strided_copy_f32(inp: *const c_void, out: *mut c_void, numel: i32, ndim: i32,
|
||||
shape0: i32, shape1: i32, shape2: i32, shape3: i32,
|
||||
in_stride0: i32, in_stride1: i32, in_stride2: i32, in_stride3: i32,
|
||||
in_offset: i32, stream: *mut c_void);
|
||||
fn launch_reshape_heads_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_merge_heads_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_transpose_hsd_to_shd_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_transpose_shd_to_hsd_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_repeat_kv_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
kv_heads: i32,
|
||||
n_rep: i32,
|
||||
seq_len: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_strided_copy_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
numel: i32,
|
||||
ndim: i32,
|
||||
shape0: i32,
|
||||
shape1: i32,
|
||||
shape2: i32,
|
||||
shape3: i32,
|
||||
in_stride0: i32,
|
||||
in_stride1: i32,
|
||||
in_stride2: i32,
|
||||
in_stride3: i32,
|
||||
in_offset: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_strided_copy_f32(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
numel: i32,
|
||||
ndim: i32,
|
||||
shape0: i32,
|
||||
shape1: i32,
|
||||
shape2: i32,
|
||||
shape3: i32,
|
||||
in_stride0: i32,
|
||||
in_stride1: i32,
|
||||
in_stride2: i32,
|
||||
in_stride3: i32,
|
||||
in_offset: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
/// [S, H*D] → [1, H, S, D] on GPU (BF16)
|
||||
@@ -24,8 +84,12 @@ pub fn reshape_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim:
|
||||
let out = Tensor::empty(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_reshape_heads_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
seq_len as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
@@ -39,36 +103,58 @@ pub fn merge_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: u
|
||||
let out = Tensor::empty(&[seq_len, hidden], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_merge_heads_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
seq_len as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// [1, H, S, D] → [S, H, D] for RoPE on GPU (BF16)
|
||||
pub fn transpose_for_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
pub fn transpose_for_rope_gpu(
|
||||
x: &Tensor,
|
||||
seq_len: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::empty(&[seq_len, num_heads, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_transpose_hsd_to_shd_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
seq_len as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// [S, H, D] → [1, H, S, D] after RoPE on GPU (BF16)
|
||||
pub fn transpose_from_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
pub fn transpose_from_rope_gpu(
|
||||
x: &Tensor,
|
||||
seq_len: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::empty(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_transpose_shd_to_hsd_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
seq_len as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
@@ -76,7 +162,9 @@ pub fn transpose_from_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, hea
|
||||
|
||||
/// [1, KV_H, S, D] → [1, KV_H*n_rep, S, D] on GPU (BF16)
|
||||
pub fn repeat_kv_gpu(x: &Tensor, n_rep: usize) -> Tensor {
|
||||
if n_rep == 1 { return x.clone(); }
|
||||
if n_rep == 1 {
|
||||
return x.clone();
|
||||
}
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let kv_heads = x.shape()[1];
|
||||
@@ -86,8 +174,13 @@ pub fn repeat_kv_gpu(x: &Tensor, n_rep: usize) -> Tensor {
|
||||
let out = Tensor::empty(&[1, new_heads, seq_len, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_repeat_kv_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
kv_heads as i32, n_rep as i32, seq_len as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
kv_heads as i32,
|
||||
n_rep as i32,
|
||||
seq_len as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
@@ -122,20 +215,41 @@ pub fn strided_to_contiguous_gpu(x: &Tensor) -> Tensor {
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::BF16 => launch_strided_copy_bf16(
|
||||
storage_ptr as _, out.data_ptr() as *mut c_void,
|
||||
numel as i32, ndim as i32,
|
||||
shape4[0], shape4[1], shape4[2], shape4[3],
|
||||
strides4[0], strides4[1], strides4[2], strides4[3],
|
||||
in_offset, xserv_cuda::current_stream_raw(),
|
||||
storage_ptr as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
numel as i32,
|
||||
ndim as i32,
|
||||
shape4[0],
|
||||
shape4[1],
|
||||
shape4[2],
|
||||
shape4[3],
|
||||
strides4[0],
|
||||
strides4[1],
|
||||
strides4[2],
|
||||
strides4[3],
|
||||
in_offset,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::F32 => launch_strided_copy_f32(
|
||||
storage_ptr as _, out.data_ptr() as *mut c_void,
|
||||
numel as i32, ndim as i32,
|
||||
shape4[0], shape4[1], shape4[2], shape4[3],
|
||||
strides4[0], strides4[1], strides4[2], strides4[3],
|
||||
in_offset, xserv_cuda::current_stream_raw(),
|
||||
storage_ptr as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
numel as i32,
|
||||
ndim as i32,
|
||||
shape4[0],
|
||||
shape4[1],
|
||||
shape4[2],
|
||||
shape4[3],
|
||||
strides4[0],
|
||||
strides4[1],
|
||||
strides4[2],
|
||||
strides4[3],
|
||||
in_offset,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!(
|
||||
"strided_to_contiguous_gpu: unsupported dtype {:?}",
|
||||
x.dtype()
|
||||
),
|
||||
_ => panic!("strided_to_contiguous_gpu: unsupported dtype {:?}", x.dtype()),
|
||||
}
|
||||
}
|
||||
out
|
||||
|
||||
@@ -1,11 +1,21 @@
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{Device, Tensor};
|
||||
|
||||
fn init() { xserv_cuda::device::set_device(0).unwrap(); }
|
||||
fn init() {
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
}
|
||||
|
||||
fn cpu_attention(q: &[f32], k: &[f32], v: &[f32],
|
||||
batch: usize, heads: usize, q_len: usize, kv_len: usize, head_dim: usize,
|
||||
causal: bool) -> Vec<f32> {
|
||||
fn cpu_attention(
|
||||
q: &[f32],
|
||||
k: &[f32],
|
||||
v: &[f32],
|
||||
batch: usize,
|
||||
heads: usize,
|
||||
q_len: usize,
|
||||
kv_len: usize,
|
||||
head_dim: usize,
|
||||
causal: bool,
|
||||
) -> Vec<f32> {
|
||||
let mut out = vec![0.0f32; batch * heads * q_len * head_dim];
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
|
||||
@@ -70,8 +80,13 @@ fn check_close(a: &[f32], b: &[f32], atol: f32, name: &str) {
|
||||
let mut max_err = 0.0f32;
|
||||
for (i, (x, y)) in a.iter().zip(b).enumerate() {
|
||||
let err = (x - y).abs();
|
||||
if err > max_err { max_err = err; }
|
||||
assert!(err <= atol, "{name}: mismatch at [{i}]: got {x}, expected {y}, err {err}");
|
||||
if err > max_err {
|
||||
max_err = err;
|
||||
}
|
||||
assert!(
|
||||
err <= atol,
|
||||
"{name}: mismatch at [{i}]: got {x}, expected {y}, err {err}"
|
||||
);
|
||||
}
|
||||
println!("{name}: max_err = {max_err:.6e}");
|
||||
}
|
||||
@@ -105,7 +120,9 @@ fn test_batched_matmul() {
|
||||
for i in 0..m {
|
||||
for j in 0..n {
|
||||
let mut s = 0.0f32;
|
||||
for kk in 0..k { s += a_cpu[i * k + kk] * b_cpu[kk * n + j]; }
|
||||
for kk in 0..k {
|
||||
s += a_cpu[i * k + kk] * b_cpu[kk * n + j];
|
||||
}
|
||||
expected[i * n + j] = s;
|
||||
}
|
||||
}
|
||||
@@ -116,7 +133,10 @@ fn test_batched_matmul() {
|
||||
#[test]
|
||||
fn test_attention_no_causal() {
|
||||
init();
|
||||
let b = 1; let h = 2; let s = 8; let d = 16;
|
||||
let b = 1;
|
||||
let h = 2;
|
||||
let s = 8;
|
||||
let d = 16;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data = make_data(b * h * s * d);
|
||||
@@ -126,13 +146,21 @@ fn test_attention_no_causal() {
|
||||
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let out = attention(&q, &k, &v, false).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-4, "attention_no_causal");
|
||||
check_close(
|
||||
out.as_slice::<f32>(),
|
||||
&expected,
|
||||
1e-4,
|
||||
"attention_no_causal",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_causal() {
|
||||
init();
|
||||
let b = 1; let h = 2; let s = 16; let d = 32;
|
||||
let b = 1;
|
||||
let h = 2;
|
||||
let s = 16;
|
||||
let d = 32;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data = make_data(b * h * s * d);
|
||||
@@ -148,7 +176,10 @@ fn test_attention_causal() {
|
||||
#[test]
|
||||
fn test_attention_causal_larger() {
|
||||
init();
|
||||
let b = 2; let h = 4; let s = 64; let d = 64;
|
||||
let b = 2;
|
||||
let h = 4;
|
||||
let s = 64;
|
||||
let d = 64;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data = make_data(b * h * s * d);
|
||||
@@ -158,18 +189,28 @@ fn test_attention_causal_larger() {
|
||||
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let out = attention(&q, &k, &v, true).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-2, "attention_causal_larger");
|
||||
check_close(
|
||||
out.as_slice::<f32>(),
|
||||
&expected,
|
||||
1e-2,
|
||||
"attention_causal_larger",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_causal_first_row_sees_only_first_token() {
|
||||
init();
|
||||
let b = 1; let h = 1; let s = 4; let d = 8;
|
||||
let b = 1;
|
||||
let h = 1;
|
||||
let s = 4;
|
||||
let d = 8;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data: Vec<f32> = (0..s * d).map(|i| {
|
||||
let v_data: Vec<f32> = (0..s * d)
|
||||
.map(|i| {
|
||||
if i < d { 1.0 } else { 0.0 } // only first V row is nonzero
|
||||
}).collect();
|
||||
})
|
||||
.collect();
|
||||
|
||||
let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
@@ -181,7 +222,11 @@ fn test_attention_causal_first_row_sees_only_first_token() {
|
||||
// output[0] should be exactly V[0] = [1, 1, 1, ...1]
|
||||
let result = out.as_slice::<f32>();
|
||||
for i in 0..d {
|
||||
assert!((result[i] - 1.0).abs() < 1e-5,
|
||||
"first row should equal V[0], got {} at dim {}", result[i], i);
|
||||
assert!(
|
||||
(result[i] - 1.0).abs() < 1e-5,
|
||||
"first row should equal V[0], got {} at dim {}",
|
||||
result[i],
|
||||
i
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use half::bf16;
|
||||
use xserv_kernels::{matmul, GemmBackend};
|
||||
use xserv_kernels::{GemmBackend, matmul};
|
||||
use xserv_tensor::{Device, Tensor};
|
||||
|
||||
fn cpu_matmul_f32(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
|
||||
@@ -75,70 +75,110 @@ fn run_gemm_test_bf16(backend: GemmBackend, m: usize, n: usize, k: usize) {
|
||||
// --- F32 tests ---
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_f32_small() { run_gemm_test_f32(GemmBackend::Naive, 4, 4, 4); }
|
||||
fn test_gemm_naive_f32_small() {
|
||||
run_gemm_test_f32(GemmBackend::Naive, 4, 4, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_f32_medium() { run_gemm_test_f32(GemmBackend::Naive, 64, 64, 64); }
|
||||
fn test_gemm_naive_f32_medium() {
|
||||
run_gemm_test_f32(GemmBackend::Naive, 64, 64, 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_f32_rect() { run_gemm_test_f32(GemmBackend::Naive, 32, 64, 48); }
|
||||
fn test_gemm_naive_f32_rect() {
|
||||
run_gemm_test_f32(GemmBackend::Naive, 32, 64, 48);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_f32_small() { run_gemm_test_f32(GemmBackend::Tiled, 4, 4, 4); }
|
||||
fn test_gemm_tiled_f32_small() {
|
||||
run_gemm_test_f32(GemmBackend::Tiled, 4, 4, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_f32_medium() { run_gemm_test_f32(GemmBackend::Tiled, 128, 128, 128); }
|
||||
fn test_gemm_tiled_f32_medium() {
|
||||
run_gemm_test_f32(GemmBackend::Tiled, 128, 128, 128);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_f32_rect() { run_gemm_test_f32(GemmBackend::Tiled, 65, 33, 97); }
|
||||
fn test_gemm_tiled_f32_rect() {
|
||||
run_gemm_test_f32(GemmBackend::Tiled, 65, 33, 97);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_small() { run_gemm_test_f32(GemmBackend::CuBlas, 4, 4, 4); }
|
||||
fn test_gemm_cublas_f32_small() {
|
||||
run_gemm_test_f32(GemmBackend::CuBlas, 4, 4, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_medium() { run_gemm_test_f32(GemmBackend::CuBlas, 256, 256, 256); }
|
||||
fn test_gemm_cublas_f32_medium() {
|
||||
run_gemm_test_f32(GemmBackend::CuBlas, 256, 256, 256);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_rect() { run_gemm_test_f32(GemmBackend::CuBlas, 65, 33, 97); }
|
||||
fn test_gemm_cublas_f32_rect() {
|
||||
run_gemm_test_f32(GemmBackend::CuBlas, 65, 33, 97);
|
||||
}
|
||||
|
||||
// --- BF16 tests ---
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_bf16_small() { run_gemm_test_bf16(GemmBackend::Naive, 4, 4, 4); }
|
||||
fn test_gemm_naive_bf16_small() {
|
||||
run_gemm_test_bf16(GemmBackend::Naive, 4, 4, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_bf16_medium() { run_gemm_test_bf16(GemmBackend::Naive, 64, 64, 64); }
|
||||
fn test_gemm_naive_bf16_medium() {
|
||||
run_gemm_test_bf16(GemmBackend::Naive, 64, 64, 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_bf16_small() { run_gemm_test_bf16(GemmBackend::Tiled, 4, 4, 4); }
|
||||
fn test_gemm_tiled_bf16_small() {
|
||||
run_gemm_test_bf16(GemmBackend::Tiled, 4, 4, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_bf16_medium() { run_gemm_test_bf16(GemmBackend::Tiled, 128, 128, 128); }
|
||||
fn test_gemm_tiled_bf16_medium() {
|
||||
run_gemm_test_bf16(GemmBackend::Tiled, 128, 128, 128);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_bf16_small() { run_gemm_test_bf16(GemmBackend::CuBlas, 4, 4, 4); }
|
||||
fn test_gemm_cublas_bf16_small() {
|
||||
run_gemm_test_bf16(GemmBackend::CuBlas, 4, 4, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_bf16_medium() { run_gemm_test_bf16(GemmBackend::CuBlas, 256, 256, 256); }
|
||||
fn test_gemm_cublas_bf16_medium() {
|
||||
run_gemm_test_bf16(GemmBackend::CuBlas, 256, 256, 256);
|
||||
}
|
||||
|
||||
// --- Custom GEMV tests (M=1, BF16 fast path) ---
|
||||
|
||||
#[test]
|
||||
fn test_gemv_bf16_small() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 64, 64); }
|
||||
fn test_gemv_bf16_small() {
|
||||
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 64, 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemv_bf16_medium() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 256, 256); }
|
||||
fn test_gemv_bf16_medium() {
|
||||
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 256, 256);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemv_bf16_4096() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 4096, 4096); }
|
||||
fn test_gemv_bf16_4096() {
|
||||
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 4096, 4096);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemv_bf16_rect() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 512, 4096); }
|
||||
fn test_gemv_bf16_rect() {
|
||||
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 512, 4096);
|
||||
}
|
||||
|
||||
// --- Larger benchmark-style tests ---
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_1024() { run_gemm_test_f32(GemmBackend::CuBlas, 1024, 1024, 1024); }
|
||||
fn test_gemm_cublas_f32_1024() {
|
||||
run_gemm_test_f32(GemmBackend::CuBlas, 1024, 1024, 1024);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_consistency_all_backends() {
|
||||
|
||||
@@ -2,7 +2,9 @@ use half::bf16;
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{Device, Tensor};
|
||||
|
||||
fn init() { xserv_cuda::device::set_device(0).unwrap(); }
|
||||
fn init() {
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
}
|
||||
|
||||
// --- CPU reference implementations ---
|
||||
|
||||
@@ -37,10 +39,12 @@ fn cpu_layernorm(x: &[f32], gamma: &[f32], beta: &[f32], eps: f32, hidden: usize
|
||||
|
||||
fn cpu_gelu(x: &[f32]) -> Vec<f32> {
|
||||
let sqrt_2_over_pi = 0.7978845608f32;
|
||||
x.iter().map(|&v| {
|
||||
x.iter()
|
||||
.map(|&v| {
|
||||
let inner = sqrt_2_over_pi * (v + 0.044715 * v * v * v);
|
||||
0.5 * v * (1.0 + inner.tanh())
|
||||
}).collect()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn cpu_silu(x: &[f32]) -> Vec<f32> {
|
||||
@@ -88,8 +92,13 @@ fn check_close(result: &[f32], expected: &[f32], atol: f32, name: &str) {
|
||||
let mut max_err = 0.0f32;
|
||||
for (i, (r, e)) in result.iter().zip(expected).enumerate() {
|
||||
let err = (r - e).abs();
|
||||
if err > max_err { max_err = err; }
|
||||
assert!(err <= atol, "{name}: mismatch at [{i}]: got {r}, expected {e}, err {err}");
|
||||
if err > max_err {
|
||||
max_err = err;
|
||||
}
|
||||
assert!(
|
||||
err <= atol,
|
||||
"{name}: mismatch at [{i}]: got {r}, expected {e}, err {err}"
|
||||
);
|
||||
}
|
||||
println!("{name}: max_err = {max_err:.6e}");
|
||||
}
|
||||
@@ -208,13 +217,18 @@ fn test_softmax_sum_to_one() {
|
||||
init();
|
||||
let rows = 4;
|
||||
let cols = 2048;
|
||||
let data: Vec<f32> = (0..rows * cols).map(|i| ((i % 31) as f32 - 15.0) * 0.5).collect();
|
||||
let data: Vec<f32> = (0..rows * cols)
|
||||
.map(|i| ((i % 31) as f32 - 15.0) * 0.5)
|
||||
.collect();
|
||||
let x = Tensor::from_slice(&data, &[rows, cols]).to_device(Device::Cuda(0));
|
||||
let out = softmax(&x).to_device(Device::Cpu);
|
||||
let result = out.as_slice::<f32>();
|
||||
for r in 0..rows {
|
||||
let row_sum: f32 = result[r * cols..(r + 1) * cols].iter().sum();
|
||||
assert!((row_sum - 1.0).abs() < 1e-5, "softmax row {r} sum = {row_sum}");
|
||||
assert!(
|
||||
(row_sum - 1.0).abs() < 1e-5,
|
||||
"softmax row {r} sum = {row_sum}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,8 +261,10 @@ fn test_embedding_f32() {
|
||||
for i in 0..hidden {
|
||||
let expected = table_data[tid as usize * hidden + i];
|
||||
let got = result[seq_idx * hidden + i];
|
||||
assert!((got - expected).abs() < 1e-6,
|
||||
"embedding mismatch at [{seq_idx},{i}]: got {got}, expected {expected}");
|
||||
assert!(
|
||||
(got - expected).abs() < 1e-6,
|
||||
"embedding mismatch at [{seq_idx},{i}]: got {got}, expected {expected}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -270,8 +286,8 @@ fn test_rope_f32() {
|
||||
let mut expected = x_data.clone();
|
||||
cpu_rope(&mut expected, &positions, num_heads, head_dim, theta);
|
||||
|
||||
let x = Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim])
|
||||
.to_device(Device::Cuda(0));
|
||||
let x =
|
||||
Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim]).to_device(Device::Cuda(0));
|
||||
let cache = RopeCache::new(64, head_dim, theta);
|
||||
rope_inplace(&x, &cache, &positions);
|
||||
|
||||
@@ -292,8 +308,8 @@ fn test_rope_position_0_identity() {
|
||||
.map(|i| (i as f32 + 1.0) * 0.1)
|
||||
.collect();
|
||||
|
||||
let x = Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim])
|
||||
.to_device(Device::Cuda(0));
|
||||
let x =
|
||||
Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim]).to_device(Device::Cuda(0));
|
||||
let cache = RopeCache::new(64, head_dim, 10000.0);
|
||||
rope_inplace(&x, &cache, &positions);
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use xserv_distributed::{TpContext, UniqueId, get_unique_id};
|
||||
use xserv_model::{loader, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, BLOCK_SIZE};
|
||||
use xserv_model::{BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, loader};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -23,8 +23,12 @@ fn main() {
|
||||
|
||||
eprintln!(
|
||||
"gpt-oss-20b: layers={}, hidden={}, heads={}/{} kv, experts={}, top_k={}, vocab={}",
|
||||
config.num_layers(), config.hidden(), config.num_heads(),
|
||||
config.num_kv_heads(), config.num_experts(), config.experts_per_token(),
|
||||
config.num_layers(),
|
||||
config.hidden(),
|
||||
config.num_heads(),
|
||||
config.num_kv_heads(),
|
||||
config.num_experts(),
|
||||
config.experts_per_token(),
|
||||
config.vocab_size
|
||||
);
|
||||
eprintln!("TP world={world}, max_tokens={max_tokens}");
|
||||
@@ -59,17 +63,29 @@ fn main() {
|
||||
let tp0 = Arc::new(TpContext::init(0, world, uid, 0));
|
||||
eprintln!("[rank 0] Loading weights...");
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cpu);
|
||||
eprintln!("[rank 0] Loaded {} tensors, building model...", weights.len());
|
||||
eprintln!(
|
||||
"[rank 0] Loaded {} tensors, building model...",
|
||||
weights.len()
|
||||
);
|
||||
let model = GptOss::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp0));
|
||||
let total_blocks = max_blocks_per_seq + 64;
|
||||
let mut cache = PagedKVCache::new_tp(
|
||||
&config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, 0,
|
||||
&config,
|
||||
local_kv,
|
||||
total_blocks,
|
||||
0,
|
||||
4,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
0,
|
||||
);
|
||||
eprintln!("[rank 0] Ready.");
|
||||
|
||||
// Prompt
|
||||
let prompt_arg = get_arg::<String>(&args, "--prompt");
|
||||
let prompt = prompt_arg.as_deref().unwrap_or("What is the meaning of life?");
|
||||
let prompt = prompt_arg
|
||||
.as_deref()
|
||||
.unwrap_or("What is the meaning of life?");
|
||||
let token_ids = tokenizer.encode(prompt);
|
||||
eprintln!("Prompt ({} tokens): {prompt}", token_ids.len());
|
||||
|
||||
@@ -83,11 +99,21 @@ fn main() {
|
||||
// (oracle) next token. Removes free-running compounding so it isolates
|
||||
// whether per-position logits agree with the llama.cpp trajectory.
|
||||
if let Some(forced) = get_arg::<String>(&args, "--forced") {
|
||||
let forced_ids: Vec<u32> = forced.split(',').filter_map(|s| s.trim().parse().ok()).collect();
|
||||
let forced_ids: Vec<u32> = forced
|
||||
.split(',')
|
||||
.filter_map(|s| s.trim().parse().ok())
|
||||
.collect();
|
||||
let mut seq = token_ids.clone();
|
||||
seq.extend_from_slice(&forced_ids);
|
||||
// Workers must run the same prefill in lockstep (TP AllReduces match up).
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { tokens: seq.clone(), slot });
|
||||
broadcast_cmd(
|
||||
&worker_txs,
|
||||
&worker_handles,
|
||||
WorkerCmd::Prefill {
|
||||
tokens: seq.clone(),
|
||||
slot,
|
||||
},
|
||||
);
|
||||
let logits = model.forward_prefill_paged(&seq, slot, &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
@@ -99,19 +125,31 @@ fn main() {
|
||||
// position i predicts seq[i+1]; we check the forced region
|
||||
for i in (plen - 1)..(seq.len() - 1) {
|
||||
let row = &data[i * vocab..(i + 1) * vocab];
|
||||
let argmax = row.iter().enumerate()
|
||||
let argmax = row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(j, _)| j as u32).unwrap();
|
||||
.map(|(j, _)| j as u32)
|
||||
.unwrap();
|
||||
let expected = seq[i + 1];
|
||||
let ok = argmax == expected;
|
||||
if ok { matches += 1; }
|
||||
total += 1;
|
||||
eprintln!("pos {i}: xserv_argmax={argmax} oracle={expected} {}", if ok {"OK"} else {"DIFF"});
|
||||
if ok {
|
||||
matches += 1;
|
||||
}
|
||||
eprintln!("\nTeacher-forced top-1 agreement: {matches}/{total} = {:.1}%",
|
||||
100.0 * matches as f64 / total as f64);
|
||||
total += 1;
|
||||
eprintln!(
|
||||
"pos {i}: xserv_argmax={argmax} oracle={expected} {}",
|
||||
if ok { "OK" } else { "DIFF" }
|
||||
);
|
||||
}
|
||||
eprintln!(
|
||||
"\nTeacher-forced top-1 agreement: {matches}/{total} = {:.1}%",
|
||||
100.0 * matches as f64 / total as f64
|
||||
);
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown);
|
||||
for (h, _) in worker_handles { h.join().unwrap(); }
|
||||
for (h, _) in worker_handles {
|
||||
h.join().unwrap();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -120,8 +158,18 @@ fn main() {
|
||||
// per-position top-1 agreement bucketed by position. Localizes long-context
|
||||
// decode degradation (which prefill teacher-forcing cannot see).
|
||||
if let Some(forced) = get_arg::<String>(&args, "--forced-decode") {
|
||||
let forced_ids: Vec<u32> = forced.split(',').filter_map(|s| s.trim().parse().ok()).collect();
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { tokens: token_ids.clone(), slot });
|
||||
let forced_ids: Vec<u32> = forced
|
||||
.split(',')
|
||||
.filter_map(|s| s.trim().parse().ok())
|
||||
.collect();
|
||||
broadcast_cmd(
|
||||
&worker_txs,
|
||||
&worker_handles,
|
||||
WorkerCmd::Prefill {
|
||||
tokens: token_ids.clone(),
|
||||
slot,
|
||||
},
|
||||
);
|
||||
let logits = model.forward_prefill_paged(&token_ids, slot, &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
let mut pred = sample_greedy_last(&logits); // prediction for forced[0]
|
||||
@@ -133,34 +181,55 @@ fn main() {
|
||||
matches += ok as usize;
|
||||
total += 1;
|
||||
let b = i / bucket;
|
||||
if buckets.len() <= b { buckets.push((0, 0)); }
|
||||
if buckets.len() <= b {
|
||||
buckets.push((0, 0));
|
||||
}
|
||||
buckets[b].0 += ok as usize;
|
||||
buckets[b].1 += 1;
|
||||
// Teacher-force: feed the oracle token through the decode path.
|
||||
let pos = cache.seq_len(slot);
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Decode {
|
||||
tokens: vec![f], positions: vec![pos], slots: vec![slot],
|
||||
});
|
||||
broadcast_cmd(
|
||||
&worker_txs,
|
||||
&worker_handles,
|
||||
WorkerCmd::Decode {
|
||||
tokens: vec![f],
|
||||
positions: vec![pos],
|
||||
slots: vec![slot],
|
||||
},
|
||||
);
|
||||
let logits = model.forward_decode_paged(&[f], &[pos], &[slot], &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
pred = sample_greedy_last(&logits);
|
||||
}
|
||||
eprintln!("Teacher-forced DECODE agreement: {matches}/{total} = {:.1}%",
|
||||
100.0 * matches as f64 / total as f64);
|
||||
eprintln!(
|
||||
"Teacher-forced DECODE agreement: {matches}/{total} = {:.1}%",
|
||||
100.0 * matches as f64 / total as f64
|
||||
);
|
||||
for (b, (m, t)) in buckets.iter().enumerate() {
|
||||
eprintln!(" pos[{:>4}..{:<4}]: {m:>3}/{t:<3} = {:.0}%",
|
||||
b * bucket, b * bucket + t, 100.0 * (*m as f64) / (*t as f64));
|
||||
eprintln!(
|
||||
" pos[{:>4}..{:<4}]: {m:>3}/{t:<3} = {:.0}%",
|
||||
b * bucket,
|
||||
b * bucket + t,
|
||||
100.0 * (*m as f64) / (*t as f64)
|
||||
);
|
||||
}
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown);
|
||||
for (h, _) in worker_handles { h.join().unwrap(); }
|
||||
for (h, _) in worker_handles {
|
||||
h.join().unwrap();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Prefill
|
||||
let t0 = Instant::now();
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill {
|
||||
tokens: token_ids.clone(), slot,
|
||||
});
|
||||
broadcast_cmd(
|
||||
&worker_txs,
|
||||
&worker_handles,
|
||||
WorkerCmd::Prefill {
|
||||
tokens: token_ids.clone(),
|
||||
slot,
|
||||
},
|
||||
);
|
||||
let logits = model.forward_prefill_paged(&token_ids, slot, &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
let ttft = t0.elapsed();
|
||||
@@ -178,12 +247,20 @@ fn main() {
|
||||
let text = tokenizer.decode(&[next]);
|
||||
print!("{text}");
|
||||
|
||||
if tokenizer.eos_token_id() == Some(next) { break; }
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
|
||||
let pos = cache.seq_len(slot);
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Decode {
|
||||
tokens: vec![next], positions: vec![pos], slots: vec![slot],
|
||||
});
|
||||
broadcast_cmd(
|
||||
&worker_txs,
|
||||
&worker_handles,
|
||||
WorkerCmd::Decode {
|
||||
tokens: vec![next],
|
||||
positions: vec![pos],
|
||||
slots: vec![slot],
|
||||
},
|
||||
);
|
||||
let logits = decoder.decode(&model, &[next], &[pos], &[slot], &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
|
||||
@@ -196,13 +273,20 @@ fn main() {
|
||||
let gen_tokens = output_tokens.len();
|
||||
let full_text = tokenizer.decode(&output_tokens);
|
||||
eprintln!("\nGenerated text: {full_text}");
|
||||
eprintln!("Token IDs: {:?}", &output_tokens[..output_tokens.len().min(20)]);
|
||||
eprintln!(
|
||||
"Token IDs: {:?}",
|
||||
&output_tokens[..output_tokens.len().min(20)]
|
||||
);
|
||||
let tpot = if gen_tokens > 1 {
|
||||
decode_elapsed.as_secs_f64() * 1000.0 / (gen_tokens - 1) as f64
|
||||
} else { 0.0 };
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
let tok_s = if gen_tokens > 1 {
|
||||
(gen_tokens - 1) as f64 / decode_elapsed.as_secs_f64()
|
||||
} else { 0.0 };
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
eprintln!("\n--- Performance ---");
|
||||
eprintln!("Generated: {} tokens", gen_tokens);
|
||||
@@ -222,8 +306,15 @@ fn main() {
|
||||
#[derive(Clone)]
|
||||
enum WorkerCmd {
|
||||
Register(usize),
|
||||
Prefill { tokens: Vec<u32>, slot: usize },
|
||||
Decode { tokens: Vec<u32>, positions: Vec<usize>, slots: Vec<usize> },
|
||||
Prefill {
|
||||
tokens: Vec<u32>,
|
||||
slot: usize,
|
||||
},
|
||||
Decode {
|
||||
tokens: Vec<u32>,
|
||||
positions: Vec<usize>,
|
||||
slots: Vec<usize>,
|
||||
},
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
@@ -241,12 +332,20 @@ fn worker_loop(
|
||||
let tp = Arc::new(TpContext::init(rank, world, uid, rank as u32));
|
||||
eprintln!("[rank {rank}] Loading weights...");
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cpu);
|
||||
let model = GptOss::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp));
|
||||
let model =
|
||||
GptOss::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp));
|
||||
let local_kv = config.num_kv_heads() / world;
|
||||
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
let total_blocks = max_blocks_per_seq + 64;
|
||||
let mut cache = PagedKVCache::new_tp(
|
||||
&config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, rank as u32,
|
||||
&config,
|
||||
local_kv,
|
||||
total_blocks,
|
||||
0,
|
||||
4,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
rank as u32,
|
||||
);
|
||||
eprintln!("[rank {rank}] Ready.");
|
||||
ack_tx.send(()).unwrap();
|
||||
@@ -260,7 +359,11 @@ fn worker_loop(
|
||||
WorkerCmd::Prefill { tokens, slot } => {
|
||||
let _ = model.forward_prefill_paged(&tokens, slot, &mut cache);
|
||||
}
|
||||
WorkerCmd::Decode { tokens, positions, slots } => {
|
||||
WorkerCmd::Decode {
|
||||
tokens,
|
||||
positions,
|
||||
slots,
|
||||
} => {
|
||||
let _ = decoder.decode(&model, &tokens, &positions, &slots, &mut cache);
|
||||
}
|
||||
WorkerCmd::Shutdown => break,
|
||||
@@ -299,14 +402,15 @@ fn sample_greedy_last(logits: &xserv_tensor::Tensor) -> u32 {
|
||||
let data = logits_cpu.as_slice::<bf16>();
|
||||
let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
|
||||
|
||||
|
||||
last.iter().enumerate()
|
||||
last.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| {
|
||||
let af = a.1.to_f32();
|
||||
let bf = b.1.to_f32();
|
||||
af.partial_cmp(&bf).unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.map(|(i, _)| i as u32).unwrap()
|
||||
.map(|(i, _)| i as u32)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn get_arg<T: std::str::FromStr>(args: &[String], flag: &str) -> Option<T> {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::path::PathBuf;
|
||||
use std::time::Instant;
|
||||
use xserv_model::gpt2::{sample_greedy, KVCache};
|
||||
use xserv_model::{loader, GPT2, ModelConfig};
|
||||
use xserv_model::gpt2::{KVCache, sample_greedy};
|
||||
use xserv_model::{GPT2, ModelConfig, loader};
|
||||
use xserv_tensor::Device;
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -104,9 +104,15 @@ fn main() {
|
||||
|
||||
let tbt_us = if !token_times_us.is_empty() {
|
||||
token_times_us.iter().sum::<u128>() / token_times_us.len() as u128
|
||||
} else { 0 };
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let total_gen_us: u128 = ttft_us + token_times_us.iter().sum::<u128>();
|
||||
let tpot_us = if num_generated > 0 { total_gen_us / num_generated as u128 } else { 0 };
|
||||
let tpot_us = if num_generated > 0 {
|
||||
total_gen_us / num_generated as u128
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let gen_text_escaped = generated_text
|
||||
.replace('\\', "\\\\")
|
||||
@@ -124,11 +130,16 @@ fn main() {
|
||||
print!("\"ttft_us\": {ttft_us}, ");
|
||||
print!("\"tbt_us\": {tbt_us}, ");
|
||||
print!("\"tpot_us\": {tpot_us}}}");
|
||||
if i < prompts.len() - 1 { println!(","); } else { println!(); }
|
||||
if i < prompts.len() - 1 {
|
||||
println!(",");
|
||||
} else {
|
||||
println!();
|
||||
}
|
||||
|
||||
eprintln!(
|
||||
"[{}/{}] input={input_len}tok gen={num_generated}tok ttft={:.1}ms tbt={:.1}ms | {}",
|
||||
i + 1, prompts.len(),
|
||||
i + 1,
|
||||
prompts.len(),
|
||||
ttft_us as f64 / 1000.0,
|
||||
tbt_us as f64 / 1000.0,
|
||||
&generated_text.replace('\n', " ")[..generated_text.len().min(60)]
|
||||
@@ -138,12 +149,18 @@ fn main() {
|
||||
}
|
||||
|
||||
fn generate_with_cache(
|
||||
model: &GPT2, config: &ModelConfig, tokenizer: &Tokenizer,
|
||||
input_ids: &[u32], gen_tokens: usize,
|
||||
model: &GPT2,
|
||||
config: &ModelConfig,
|
||||
tokenizer: &Tokenizer,
|
||||
input_ids: &[u32],
|
||||
gen_tokens: usize,
|
||||
) -> (Vec<u32>, u128, Vec<u128>) {
|
||||
let mut cache = KVCache::new(
|
||||
config.num_layers(), config.num_heads(), config.head_dim(),
|
||||
xserv_tensor::DType::F32, Device::Cuda(0),
|
||||
config.num_layers(),
|
||||
config.num_heads(),
|
||||
config.head_dim(),
|
||||
xserv_tensor::DType::F32,
|
||||
Device::Cuda(0),
|
||||
);
|
||||
|
||||
// Prefill
|
||||
@@ -163,15 +180,19 @@ fn generate_with_cache(
|
||||
let next = sample_greedy(&logits);
|
||||
token_times.push(t_start.elapsed().as_micros());
|
||||
generated.push(next);
|
||||
if tokenizer.eos_token_id() == Some(next) { break; }
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
(generated, ttft_us, token_times)
|
||||
}
|
||||
|
||||
fn generate_no_cache(
|
||||
model: &GPT2, tokenizer: &Tokenizer,
|
||||
input_ids: &[u32], gen_tokens: usize,
|
||||
model: &GPT2,
|
||||
tokenizer: &Tokenizer,
|
||||
input_ids: &[u32],
|
||||
gen_tokens: usize,
|
||||
) -> (Vec<u32>, u128, Vec<u128>) {
|
||||
let mut all_ids = input_ids.to_vec();
|
||||
|
||||
@@ -191,7 +212,9 @@ fn generate_no_cache(
|
||||
token_times.push(t_start.elapsed().as_micros());
|
||||
all_ids.push(next);
|
||||
generated.push(next);
|
||||
if tokenizer.eos_token_id() == Some(next) { break; }
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
(generated, ttft_us, token_times)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::path::PathBuf;
|
||||
use std::time::Instant;
|
||||
use xserv_model::qwen3::sample_greedy;
|
||||
use xserv_model::{loader, DecodeGraphState, GpuKVCache, ModelConfig, Qwen3};
|
||||
use xserv_model::{DecodeGraphState, GpuKVCache, ModelConfig, Qwen3, loader};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -139,18 +139,35 @@ fn main() {
|
||||
} else {
|
||||
// Replay captured graphs
|
||||
let pos = cache.seq_len() as u32;
|
||||
graph.execute(last, pos, &mut cache, &layer_ptrs, embed, config.vocab_size as i32, config.hidden() as i32);
|
||||
graph.execute(
|
||||
last,
|
||||
pos,
|
||||
&mut cache,
|
||||
&layer_ptrs,
|
||||
embed,
|
||||
config.vocab_size as i32,
|
||||
config.hidden() as i32,
|
||||
);
|
||||
cache.advance_seq_len(1);
|
||||
// Read logits from graph buffer
|
||||
let vocab_size = config.vocab_size;
|
||||
let mut logits_bytes = vec![0u8; vocab_size * 2];
|
||||
graph.logits_buffer().copy_to_host(&mut logits_bytes).unwrap();
|
||||
graph
|
||||
.logits_buffer()
|
||||
.copy_to_host(&mut logits_bytes)
|
||||
.unwrap();
|
||||
let logits_data: &[half::bf16] = unsafe {
|
||||
std::slice::from_raw_parts(logits_bytes.as_ptr() as *const half::bf16, vocab_size)
|
||||
std::slice::from_raw_parts(
|
||||
logits_bytes.as_ptr() as *const half::bf16,
|
||||
vocab_size,
|
||||
)
|
||||
};
|
||||
logits_data.iter().enumerate()
|
||||
logits_data
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(idx, _)| idx as u32).unwrap()
|
||||
.map(|(idx, _)| idx as u32)
|
||||
.unwrap()
|
||||
}
|
||||
} else {
|
||||
let logits = model.forward_gpu_cache(&[last], &mut cache);
|
||||
@@ -159,16 +176,24 @@ fn main() {
|
||||
|
||||
token_times.push(t_start.elapsed().as_micros());
|
||||
generated.push(next);
|
||||
if tokenizer.eos_token_id() == Some(next) { break; }
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let num_generated = generated.len();
|
||||
let generated_text = tokenizer.decode(&generated);
|
||||
let tbt_us = if !token_times.is_empty() {
|
||||
token_times.iter().sum::<u128>() / token_times.len() as u128
|
||||
} else { 0 };
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let total_gen_us: u128 = ttft_us + token_times.iter().sum::<u128>();
|
||||
let tpot_us = if num_generated > 0 { total_gen_us / num_generated as u128 } else { 0 };
|
||||
let tpot_us = if num_generated > 0 {
|
||||
total_gen_us / num_generated as u128
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let gen_text_escaped = generated_text
|
||||
.replace('\\', "\\\\")
|
||||
@@ -186,13 +211,18 @@ fn main() {
|
||||
print!("\"ttft_us\": {ttft_us}, ");
|
||||
print!("\"tbt_us\": {tbt_us}, ");
|
||||
print!("\"tpot_us\": {tpot_us}}}");
|
||||
if i < prompts.len() - 1 { println!(","); } else { println!(); }
|
||||
if i < prompts.len() - 1 {
|
||||
println!(",");
|
||||
} else {
|
||||
println!();
|
||||
}
|
||||
|
||||
let display_text = generated_text.replace('\n', " ");
|
||||
let truncated: String = display_text.chars().take(60).collect();
|
||||
eprintln!(
|
||||
"[{}/{}] input={input_len}tok gen={num_generated}tok ttft={:.1}ms tbt={:.1}ms | {}",
|
||||
i + 1, prompts.len(),
|
||||
i + 1,
|
||||
prompts.len(),
|
||||
ttft_us as f64 / 1000.0,
|
||||
tbt_us as f64 / 1000.0,
|
||||
truncated
|
||||
|
||||
@@ -18,7 +18,7 @@ use std::thread;
|
||||
use std::time::Instant;
|
||||
|
||||
use xserv_model::qwen3::sample_greedy;
|
||||
use xserv_model::{loader, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
|
||||
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -35,8 +35,13 @@ fn main() {
|
||||
std::process::exit(1);
|
||||
}
|
||||
let model_dir = PathBuf::from(&args[1]);
|
||||
let world: usize = arg(&args, "--tp").and_then(|s| s.parse().ok()).unwrap_or(1).max(1);
|
||||
let gen_tokens: usize = arg(&args, "--gen-tokens").and_then(|s| s.parse().ok()).unwrap_or(64);
|
||||
let world: usize = arg(&args, "--tp")
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1)
|
||||
.max(1);
|
||||
let gen_tokens: usize = arg(&args, "--gen-tokens")
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(64);
|
||||
let devices: Vec<u32> = match arg(&args, "--devices") {
|
||||
Some(s) => s.split(',').filter_map(|d| d.trim().parse().ok()).collect(),
|
||||
None => (0..world as u32).collect(),
|
||||
@@ -67,7 +72,11 @@ fn main() {
|
||||
// Tensors are not Send (their Storage holds a raw GPU pointer), so each rank
|
||||
// thread loads its own CPU copy of the weights and shards in-thread. Loading
|
||||
// is not part of the timed region.
|
||||
let id = if world > 1 { Some(xserv_distributed::get_unique_id()) } else { None };
|
||||
let id = if world > 1 {
|
||||
Some(xserv_distributed::get_unique_id())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let handles: Vec<_> = (0..world)
|
||||
.map(|rank| {
|
||||
@@ -76,7 +85,9 @@ fn main() {
|
||||
let prompt_ids = prompt_ids.clone();
|
||||
let device = devices[rank];
|
||||
thread::spawn(move || {
|
||||
run_rank(rank, world, device, id, config, model_dir, prompt_ids, gen_tokens, eos)
|
||||
run_rank(
|
||||
rank, world, device, id, config, model_dir, prompt_ids, gen_tokens, eos,
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
@@ -91,7 +102,10 @@ fn main() {
|
||||
|
||||
let results = rank0.expect("rank 0 produced no results");
|
||||
println!("\n=== TP={world} (devices {devices:?}) — Qwen3 E2E benchmark ===");
|
||||
println!("{:<45} {:>10} {:>12} {:>8}", "prompt", "TTFT(ms)", "decode tok/s", "gen");
|
||||
println!(
|
||||
"{:<45} {:>10} {:>12} {:>8}",
|
||||
"prompt", "TTFT(ms)", "decode tok/s", "gen"
|
||||
);
|
||||
let mut tps_sum = 0.0;
|
||||
for (i, r) in results.iter().enumerate() {
|
||||
let text = tokenizer.decode(&r.gen_ids).replace('\n', " ");
|
||||
@@ -99,16 +113,29 @@ fn main() {
|
||||
let p: String = prompts[i].chars().take(43).collect();
|
||||
println!(
|
||||
"{:<45} {:>10.1} {:>12.1} {:>8} | {}",
|
||||
p, r.ttft_ms, r.decode_tok_s, r.gen_ids.len(), short
|
||||
p,
|
||||
r.ttft_ms,
|
||||
r.decode_tok_s,
|
||||
r.gen_ids.len(),
|
||||
short
|
||||
);
|
||||
tps_sum += r.decode_tok_s;
|
||||
}
|
||||
println!("--- mean decode throughput: {:.1} tok/s ---", tps_sum / results.len() as f64);
|
||||
println!(
|
||||
"--- mean decode throughput: {:.1} tok/s ---",
|
||||
tps_sum / results.len() as f64
|
||||
);
|
||||
|
||||
// Machine-readable line for cross-TP correctness diffing (rank 0 token ids).
|
||||
let all_ids: Vec<String> = results
|
||||
.iter()
|
||||
.map(|r| r.gen_ids.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(","))
|
||||
.map(|r| {
|
||||
r.gen_ids
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
})
|
||||
.collect();
|
||||
println!("CORRECTNESS_IDS tp={world} {}", all_ids.join(" | "));
|
||||
}
|
||||
@@ -126,7 +153,12 @@ fn run_rank(
|
||||
) -> Option<Vec<PromptResult>> {
|
||||
// Bind this thread to its GPU and set up the TP communicator.
|
||||
let tp = if world > 1 {
|
||||
Some(Arc::new(xserv_distributed::TpContext::init(rank, world, id.unwrap(), device)))
|
||||
Some(Arc::new(xserv_distributed::TpContext::init(
|
||||
rank,
|
||||
world,
|
||||
id.unwrap(),
|
||||
device,
|
||||
)))
|
||||
} else {
|
||||
xserv_cuda::device::set_device(device).unwrap();
|
||||
None
|
||||
@@ -142,7 +174,14 @@ fn run_rank(
|
||||
let max_blocks_per_seq = max_seq.div_ceil(BLOCK_SIZE);
|
||||
let total_blocks = max_blocks_per_seq + 8;
|
||||
let mut cache = PagedKVCache::new_tp(
|
||||
&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, device,
|
||||
&config,
|
||||
local_kv,
|
||||
total_blocks,
|
||||
0,
|
||||
1,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
device,
|
||||
);
|
||||
|
||||
// Warmup (init kernels / allocator / NCCL channels) — not timed.
|
||||
@@ -177,12 +216,20 @@ fn run_rank(
|
||||
steps += 1;
|
||||
}
|
||||
let decode_s = t1.elapsed().as_secs_f64();
|
||||
let decode_tok_s = if steps > 0 && decode_s > 0.0 { steps as f64 / decode_s } else { 0.0 };
|
||||
let decode_tok_s = if steps > 0 && decode_s > 0.0 {
|
||||
steps as f64 / decode_s
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
cache.free_sequence(0);
|
||||
|
||||
if rank == 0 {
|
||||
out.push(PromptResult { gen_ids: generated, ttft_ms, decode_tok_s });
|
||||
out.push(PromptResult {
|
||||
gen_ids: generated,
|
||||
ttft_ms,
|
||||
decode_tok_s,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,5 +237,8 @@ fn run_rank(
|
||||
}
|
||||
|
||||
fn arg<'a>(args: &'a [String], flag: &str) -> Option<&'a str> {
|
||||
args.iter().position(|a| a == flag).and_then(|i| args.get(i + 1)).map(|s| s.as_str())
|
||||
args.iter()
|
||||
.position(|a| a == flag)
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.map(|s| s.as_str())
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use half::bf16;
|
||||
use std::path::PathBuf;
|
||||
use xserv_model::{loader, KVCache, ModelConfig, Qwen3};
|
||||
use xserv_model::{KVCache, ModelConfig, Qwen3, loader};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
use half::bf16;
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
@@ -20,8 +20,11 @@ fn main() {
|
||||
eprintln!("Token IDs: {token_ids:?}");
|
||||
|
||||
let mut cache = KVCache::new(
|
||||
config.num_layers(), config.num_kv_heads(), config.head_dim(),
|
||||
DType::BF16, Device::Cuda(0),
|
||||
config.num_layers(),
|
||||
config.num_kv_heads(),
|
||||
config.head_dim(),
|
||||
DType::BF16,
|
||||
Device::Cuda(0),
|
||||
);
|
||||
let logits = model.forward_with_cache(&token_ids, &mut cache);
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
@@ -31,7 +34,9 @@ fn main() {
|
||||
|
||||
// Print top-20 logits for the last position
|
||||
let last_row = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
|
||||
let mut indexed: Vec<(usize, f32)> = last_row.iter().enumerate()
|
||||
let mut indexed: Vec<(usize, f32)> = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, v)| (i, v.to_f32()))
|
||||
.collect();
|
||||
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
use std::io::{self, IsTerminal, Read, Write};
|
||||
use std::path::PathBuf;
|
||||
|
||||
use std::sync::{mpsc, Arc};
|
||||
use std::sync::{Arc, mpsc};
|
||||
use std::thread;
|
||||
|
||||
use xserv_model::{GraphedGptOssDecoder, loader, sample, sample_greedy_penalized, GptOss, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE};
|
||||
use xserv_model::{
|
||||
BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, SamplingParams,
|
||||
loader, sample, sample_greedy_penalized,
|
||||
};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -14,13 +17,24 @@ enum ChatModel {
|
||||
}
|
||||
|
||||
impl ChatModel {
|
||||
fn forward_prefill_paged(&self, tokens: &[u32], slot: usize, cache: &mut PagedKVCache) -> xserv_tensor::Tensor {
|
||||
fn forward_prefill_paged(
|
||||
&self,
|
||||
tokens: &[u32],
|
||||
slot: usize,
|
||||
cache: &mut PagedKVCache,
|
||||
) -> xserv_tensor::Tensor {
|
||||
match self {
|
||||
ChatModel::Qwen3(m) => m.forward_prefill_paged(tokens, slot, cache),
|
||||
ChatModel::GptOss(m) => m.forward_prefill_paged(tokens, slot, cache),
|
||||
}
|
||||
}
|
||||
fn forward_decode_paged(&self, tokens: &[u32], positions: &[usize], slots: &[usize], cache: &mut PagedKVCache) -> xserv_tensor::Tensor {
|
||||
fn forward_decode_paged(
|
||||
&self,
|
||||
tokens: &[u32],
|
||||
positions: &[usize],
|
||||
slots: &[usize],
|
||||
cache: &mut PagedKVCache,
|
||||
) -> xserv_tensor::Tensor {
|
||||
match self {
|
||||
ChatModel::Qwen3(m) => m.forward_decode_paged(tokens, positions, slots, cache),
|
||||
ChatModel::GptOss(m) => m.forward_decode_paged(tokens, positions, slots, cache),
|
||||
@@ -33,8 +47,15 @@ impl ChatModel {
|
||||
enum TpCommand {
|
||||
Register(usize),
|
||||
Free(usize),
|
||||
Prefill { tokens: Vec<u32>, slot: usize },
|
||||
Decode { tokens: Vec<u32>, positions: Vec<usize>, slots: Vec<usize> },
|
||||
Prefill {
|
||||
tokens: Vec<u32>,
|
||||
slot: usize,
|
||||
},
|
||||
Decode {
|
||||
tokens: Vec<u32>,
|
||||
positions: Vec<usize>,
|
||||
slots: Vec<usize>,
|
||||
},
|
||||
}
|
||||
|
||||
struct TpHandle {
|
||||
@@ -56,7 +77,8 @@ impl TpHandle {
|
||||
}
|
||||
|
||||
fn tp_worker_loop(
|
||||
rank: usize, world: usize,
|
||||
rank: usize,
|
||||
world: usize,
|
||||
id: xserv_distributed::UniqueId,
|
||||
model_dir: std::path::PathBuf,
|
||||
config: ModelConfig,
|
||||
@@ -64,29 +86,68 @@ fn tp_worker_loop(
|
||||
cmd_rx: mpsc::Receiver<TpCommand>,
|
||||
ack_tx: mpsc::Sender<()>,
|
||||
) {
|
||||
let tp = Arc::new(xserv_distributed::TpContext::init(rank, world, id, rank as u32));
|
||||
let tp = Arc::new(xserv_distributed::TpContext::init(
|
||||
rank,
|
||||
world,
|
||||
id,
|
||||
rank as u32,
|
||||
));
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cpu);
|
||||
let model = if config.is_moe() {
|
||||
ChatModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp)))
|
||||
ChatModel::GptOss(GptOss::from_weights_tp(
|
||||
config.clone(),
|
||||
weights,
|
||||
rank,
|
||||
world,
|
||||
rank as u32,
|
||||
Some(tp),
|
||||
))
|
||||
} else {
|
||||
ChatModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp)))
|
||||
ChatModel::Qwen3(Qwen3::from_weights_tp(
|
||||
config.clone(),
|
||||
weights,
|
||||
rank,
|
||||
world,
|
||||
rank as u32,
|
||||
Some(tp),
|
||||
))
|
||||
};
|
||||
let local_kv = config.num_kv_heads() / world;
|
||||
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
let total_blocks = max_blocks_per_seq + 8;
|
||||
let mut cache = PagedKVCache::new_tp(
|
||||
&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, rank as u32,
|
||||
&config,
|
||||
local_kv,
|
||||
total_blocks,
|
||||
0,
|
||||
1,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
rank as u32,
|
||||
);
|
||||
let mut decoder = GraphedGptOssDecoder::new();
|
||||
while let Ok(cmd) = cmd_rx.recv() {
|
||||
match cmd {
|
||||
TpCommand::Register(slot) => { let _ = cache.register_sequence(slot); }
|
||||
TpCommand::Register(slot) => {
|
||||
let _ = cache.register_sequence(slot);
|
||||
}
|
||||
TpCommand::Free(slot) => cache.free_sequence(slot),
|
||||
TpCommand::Prefill { tokens, slot } => {
|
||||
let _ = model.forward_prefill_paged(&tokens, slot, &mut cache);
|
||||
}
|
||||
TpCommand::Decode { tokens, positions, slots } => {
|
||||
let _ = chat_decode(&model, &mut decoder, &tokens, &positions, &slots, &mut cache);
|
||||
TpCommand::Decode {
|
||||
tokens,
|
||||
positions,
|
||||
slots,
|
||||
} => {
|
||||
let _ = chat_decode(
|
||||
&model,
|
||||
&mut decoder,
|
||||
&tokens,
|
||||
&positions,
|
||||
&slots,
|
||||
&mut cache,
|
||||
);
|
||||
}
|
||||
}
|
||||
let _ = ack_tx.send(());
|
||||
@@ -221,7 +282,13 @@ fn read_line_edited(prompt: &str) -> Line {
|
||||
}
|
||||
b => {
|
||||
// UTF-8 multi-byte: read the continuation bytes for this char.
|
||||
let extra = if b >= 0xF0 { 3 } else if b >= 0xE0 { 2 } else { 1 };
|
||||
let extra = if b >= 0xF0 {
|
||||
3
|
||||
} else if b >= 0xE0 {
|
||||
2
|
||||
} else {
|
||||
1
|
||||
};
|
||||
let mut bytes = vec![b];
|
||||
let mut cont = [0u8; 1];
|
||||
let mut ok = true;
|
||||
@@ -275,7 +342,8 @@ fn main() {
|
||||
if world > 1 {
|
||||
assert!(
|
||||
config.num_kv_heads() % world == 0,
|
||||
"num_kv_heads {} not divisible by tp {world}", config.num_kv_heads()
|
||||
"num_kv_heads {} not divisible by tp {world}",
|
||||
config.num_kv_heads()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -290,7 +358,16 @@ fn main() {
|
||||
let model_dir = opts.model_dir.clone();
|
||||
let config = config.clone();
|
||||
thread::spawn(move || {
|
||||
tp_worker_loop(rank, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx);
|
||||
tp_worker_loop(
|
||||
rank,
|
||||
world,
|
||||
id,
|
||||
model_dir,
|
||||
config,
|
||||
max_seq_len,
|
||||
ctx_rx,
|
||||
ack_tx,
|
||||
);
|
||||
});
|
||||
}
|
||||
eprintln!("Loading weights (tp={world})...");
|
||||
@@ -298,14 +375,37 @@ fn main() {
|
||||
let weights = loader::load_model_dir(&opts.model_dir, Device::Cpu);
|
||||
eprintln!("Loaded {} tensors", weights.len());
|
||||
let m = if is_moe {
|
||||
ChatModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp)))
|
||||
ChatModel::GptOss(GptOss::from_weights_tp(
|
||||
config.clone(),
|
||||
weights,
|
||||
0,
|
||||
world,
|
||||
0,
|
||||
Some(tp),
|
||||
))
|
||||
} else {
|
||||
ChatModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp)))
|
||||
ChatModel::Qwen3(Qwen3::from_weights_tp(
|
||||
config.clone(),
|
||||
weights,
|
||||
0,
|
||||
world,
|
||||
0,
|
||||
Some(tp),
|
||||
))
|
||||
};
|
||||
let local_kv = config.num_kv_heads() / world;
|
||||
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
let total_blocks = max_blocks_per_seq + 8;
|
||||
let c = PagedKVCache::new_tp(&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, 0);
|
||||
let c = PagedKVCache::new_tp(
|
||||
&config,
|
||||
local_kv,
|
||||
total_blocks,
|
||||
0,
|
||||
1,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
0,
|
||||
);
|
||||
let h = TpHandle { cmd_txs, ack_rx };
|
||||
(m, c, Some(h))
|
||||
} else {
|
||||
@@ -323,7 +423,10 @@ fn main() {
|
||||
|
||||
let tokenizer = Tokenizer::from_file(&opts.model_dir.join("tokenizer.json"));
|
||||
let mut decoder = GraphedGptOssDecoder::new();
|
||||
if let Some(h) = &tp_handle { h.send(TpCommand::Register(SLOT)); h.wait(); }
|
||||
if let Some(h) = &tp_handle {
|
||||
h.send(TpCommand::Register(SLOT));
|
||||
h.wait();
|
||||
}
|
||||
cache.register_sequence(SLOT).expect("register chat slot");
|
||||
let use_color = opts.color && io::stdout().is_terminal();
|
||||
|
||||
@@ -365,11 +468,8 @@ fn main() {
|
||||
if is_moe {
|
||||
// Harmony multi-turn: re-render the whole conversation (prior
|
||||
// analysis dropped) and re-prefill into a freshly cleared slot.
|
||||
let prompt = build_conversation_gpt_oss(
|
||||
opts.system_prompt.as_deref(),
|
||||
&moe_history,
|
||||
input,
|
||||
);
|
||||
let prompt =
|
||||
build_conversation_gpt_oss(opts.system_prompt.as_deref(), &moe_history, input);
|
||||
let prompt_tokens = tokenizer.encode(&prompt);
|
||||
if prompt_tokens.is_empty() {
|
||||
continue;
|
||||
@@ -386,8 +486,17 @@ fn main() {
|
||||
print!("assistant> ");
|
||||
io::stdout().flush().unwrap();
|
||||
let (_finish, answer) = generate_with_paged_cache(
|
||||
&model, &mut decoder, &mut cache, &tokenizer, &prompt_tokens, &opts.sampling,
|
||||
max_new_tokens, use_color, &tp_handle, is_moe, opts.enable_thinking,
|
||||
&model,
|
||||
&mut decoder,
|
||||
&mut cache,
|
||||
&tokenizer,
|
||||
&prompt_tokens,
|
||||
&opts.sampling,
|
||||
max_new_tokens,
|
||||
use_color,
|
||||
&tp_handle,
|
||||
is_moe,
|
||||
opts.enable_thinking,
|
||||
);
|
||||
moe_history.push((input.to_string(), answer));
|
||||
println!();
|
||||
@@ -436,10 +545,24 @@ fn main() {
|
||||
);
|
||||
match finish {
|
||||
Finish::Stop { token_id } => {
|
||||
append_after_stop(&model, &mut cache, &tokenizer, max_seq_len, token_id, &tp_handle);
|
||||
append_after_stop(
|
||||
&model,
|
||||
&mut cache,
|
||||
&tokenizer,
|
||||
max_seq_len,
|
||||
token_id,
|
||||
&tp_handle,
|
||||
);
|
||||
}
|
||||
Finish::Length => {
|
||||
append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, "<|im_end|>\n", &tp_handle);
|
||||
append_text_to_cache(
|
||||
&model,
|
||||
&mut cache,
|
||||
&tokenizer,
|
||||
max_seq_len,
|
||||
"<|im_end|>\n",
|
||||
&tp_handle,
|
||||
);
|
||||
}
|
||||
}
|
||||
println!();
|
||||
@@ -448,9 +571,15 @@ fn main() {
|
||||
|
||||
/// Free and re-register the single chat KV slot (clears all cached context).
|
||||
fn reset_slot(cache: &mut PagedKVCache, tp: &Option<TpHandle>) {
|
||||
if let Some(h) = tp { h.send(TpCommand::Free(SLOT)); h.wait(); }
|
||||
if let Some(h) = tp {
|
||||
h.send(TpCommand::Free(SLOT));
|
||||
h.wait();
|
||||
}
|
||||
cache.free_sequence(SLOT);
|
||||
if let Some(h) = tp { h.send(TpCommand::Register(SLOT)); h.wait(); }
|
||||
if let Some(h) = tp {
|
||||
h.send(TpCommand::Register(SLOT));
|
||||
h.wait();
|
||||
}
|
||||
cache.register_sequence(SLOT).expect("register chat slot");
|
||||
}
|
||||
|
||||
@@ -588,7 +717,15 @@ fn new_paged_cache(config: &ModelConfig, max_seq_len: usize) -> PagedKVCache {
|
||||
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
let total_blocks = (max_blocks_per_seq + 1).max(2);
|
||||
// Single-slot interactive CLI: no swap pool (cpu_total_blocks = 0).
|
||||
PagedKVCache::new(config, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, 0)
|
||||
PagedKVCache::new(
|
||||
config,
|
||||
total_blocks,
|
||||
0,
|
||||
1,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
0,
|
||||
)
|
||||
}
|
||||
|
||||
fn build_turn_prompt(
|
||||
@@ -668,7 +805,10 @@ fn build_conversation_gpt_oss(
|
||||
/// civil-calendar conversion (same algorithm the server uses for strftime_now).
|
||||
fn today_ymd() -> String {
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
let secs = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
|
||||
let secs = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
let z = (secs / 86400) as i64 + 719468;
|
||||
let era = (if z >= 0 { z } else { z - 146096 }) / 146097;
|
||||
let doe = z - era * 146097;
|
||||
@@ -709,12 +849,32 @@ fn generate_with_paged_cache(
|
||||
is_moe: bool,
|
||||
enable_thinking: bool,
|
||||
) -> (Finish, String) {
|
||||
let harmony_end_id = if is_moe { tokenizer.special_token_id("<|end|>") } else { None };
|
||||
let harmony_channel_id = if is_moe { tokenizer.special_token_id("<|channel|>") } else { None };
|
||||
let harmony_message_id = if is_moe { tokenizer.special_token_id("<|message|>") } else { None };
|
||||
let harmony_end_id = if is_moe {
|
||||
tokenizer.special_token_id("<|end|>")
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let harmony_channel_id = if is_moe {
|
||||
tokenizer.special_token_id("<|channel|>")
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let harmony_message_id = if is_moe {
|
||||
tokenizer.special_token_id("<|message|>")
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let harmony_special: Vec<u32> = if is_moe {
|
||||
["<|channel|>", "<|start|>", "<|end|>", "<|message|>", "<|return|>"]
|
||||
.iter().filter_map(|s| tokenizer.special_token_id(s)).collect()
|
||||
[
|
||||
"<|channel|>",
|
||||
"<|start|>",
|
||||
"<|end|>",
|
||||
"<|message|>",
|
||||
"<|return|>",
|
||||
]
|
||||
.iter()
|
||||
.filter_map(|s| tokenizer.special_token_id(s))
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
@@ -722,18 +882,29 @@ fn generate_with_paged_cache(
|
||||
// "analysis" channel is rendered as thinking (gray). After <|channel|>
|
||||
// we read the channel name tokens until <|message|>.
|
||||
#[derive(PartialEq, Clone, Copy)]
|
||||
enum HarmonyState { Normal, ReadingChannel, InAnalysis, InFinal }
|
||||
let mut hstate = if is_moe { HarmonyState::InFinal } else { HarmonyState::Normal };
|
||||
enum HarmonyState {
|
||||
Normal,
|
||||
ReadingChannel,
|
||||
InAnalysis,
|
||||
InFinal,
|
||||
}
|
||||
let mut hstate = if is_moe {
|
||||
HarmonyState::InFinal
|
||||
} else {
|
||||
HarmonyState::Normal
|
||||
};
|
||||
|
||||
// Off by default. A repetition penalty over a harmony stream penalizes the
|
||||
// control tokens (<|channel|>, <|message|>, <|start|>) that MUST repeat to
|
||||
// open the final channel — so a non-1.0 default makes gpt-oss stop right
|
||||
// after the analysis block, before emitting any answer. Opt in via the env
|
||||
// var if you want it for plain (non-harmony) generation.
|
||||
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY").ok()
|
||||
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1.0);
|
||||
let rep_window: usize = std::env::var("XSERV_REP_WINDOW").ok()
|
||||
let rep_window: usize = std::env::var("XSERV_REP_WINDOW")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(512);
|
||||
let mut history: Vec<u32> = Vec::new();
|
||||
@@ -747,9 +918,16 @@ fn generate_with_paged_cache(
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(h) = tp { h.send(TpCommand::Prefill { tokens: prompt_tokens.to_vec(), slot: SLOT }); }
|
||||
if let Some(h) = tp {
|
||||
h.send(TpCommand::Prefill {
|
||||
tokens: prompt_tokens.to_vec(),
|
||||
slot: SLOT,
|
||||
});
|
||||
}
|
||||
let logits = model.forward_prefill_paged(prompt_tokens, SLOT, cache);
|
||||
if let Some(h) = tp { h.wait(); }
|
||||
if let Some(h) = tp {
|
||||
h.wait();
|
||||
}
|
||||
let mut next = pick(&logits, sampling, &history);
|
||||
let mut decode_buffer = Vec::new();
|
||||
let mut in_thinking = false;
|
||||
@@ -762,9 +940,17 @@ fn generate_with_paged_cache(
|
||||
|
||||
for _ in 0..max_tokens {
|
||||
let position = cache.seq_len(SLOT);
|
||||
if let Some(h) = tp { h.send(TpCommand::Decode { tokens: vec![next], positions: vec![position], slots: vec![SLOT] }); }
|
||||
if let Some(h) = tp {
|
||||
h.send(TpCommand::Decode {
|
||||
tokens: vec![next],
|
||||
positions: vec![position],
|
||||
slots: vec![SLOT],
|
||||
});
|
||||
}
|
||||
let logits = chat_decode(model, decoder, &[next], &[position], &[SLOT], cache);
|
||||
if let Some(h) = tp { h.wait(); }
|
||||
if let Some(h) = tp {
|
||||
h.wait();
|
||||
}
|
||||
if tokenizer.is_eos(next) {
|
||||
print_stream_text(
|
||||
&tokenizer.flush_decode_stream(&mut decode_buffer),
|
||||
@@ -775,7 +961,10 @@ fn generate_with_paged_cache(
|
||||
print_stream_text("\n</think>\n\n", true, use_color);
|
||||
}
|
||||
io::stdout().flush().unwrap();
|
||||
return (Finish::Stop { token_id: next }, tokenizer.decode(&answer_ids));
|
||||
return (
|
||||
Finish::Stop { token_id: next },
|
||||
tokenizer.decode(&answer_ids),
|
||||
);
|
||||
}
|
||||
if harmony_end_id == Some(next) {
|
||||
// <|end|> closes current segment; if in final channel, we're done
|
||||
@@ -786,7 +975,10 @@ fn generate_with_paged_cache(
|
||||
);
|
||||
if hstate == HarmonyState::InFinal {
|
||||
io::stdout().flush().unwrap();
|
||||
return (Finish::Stop { token_id: next }, tokenizer.decode(&answer_ids));
|
||||
return (
|
||||
Finish::Stop { token_id: next },
|
||||
tokenizer.decode(&answer_ids),
|
||||
);
|
||||
}
|
||||
// Closing a thinking (analysis/commentary) channel: emit the </think>
|
||||
// marker so it renders like Qwen3's thinking block.
|
||||
@@ -842,7 +1034,13 @@ fn generate_with_paged_cache(
|
||||
// Analysis channel = the model's reasoning. With --think, show it as a
|
||||
// thinking block (gray if color); otherwise suppress it (answer only).
|
||||
if show_thinking {
|
||||
print_generated_token(tokenizer, next, &mut decode_buffer, &mut in_thinking, use_color);
|
||||
print_generated_token(
|
||||
tokenizer,
|
||||
next,
|
||||
&mut decode_buffer,
|
||||
&mut in_thinking,
|
||||
use_color,
|
||||
);
|
||||
io::stdout().flush().unwrap();
|
||||
}
|
||||
next = pick(&logits, sampling, &history);
|
||||
@@ -904,9 +1102,16 @@ fn append_text_to_cache(
|
||||
if tokens.is_empty() || cache.seq_len(SLOT) + tokens.len() > max_seq_len {
|
||||
return;
|
||||
}
|
||||
if let Some(h) = tp { h.send(TpCommand::Prefill { tokens: tokens.clone(), slot: SLOT }); }
|
||||
if let Some(h) = tp {
|
||||
h.send(TpCommand::Prefill {
|
||||
tokens: tokens.clone(),
|
||||
slot: SLOT,
|
||||
});
|
||||
}
|
||||
let _ = model.forward_prefill_paged(&tokens, SLOT, cache);
|
||||
if let Some(h) = tp { h.wait(); }
|
||||
if let Some(h) = tp {
|
||||
h.wait();
|
||||
}
|
||||
}
|
||||
|
||||
fn print_generated_token(
|
||||
@@ -952,4 +1157,3 @@ fn print_stream_text(text: &str, in_thinking: bool, use_color: bool) {
|
||||
print!("{text}");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::io::{self, Write};
|
||||
use std::path::PathBuf;
|
||||
use xserv_model::{loader, KVCache, ModelConfig, PagedKVCache, BLOCK_SIZE};
|
||||
use xserv_model::{BLOCK_SIZE, KVCache, ModelConfig, PagedKVCache, loader};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -21,14 +21,21 @@ fn main() {
|
||||
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
let info = xserv_cuda::device::device_info(0).unwrap();
|
||||
eprintln!("GPU: {} ({} MB free)", info.name, info.free_memory / 1024 / 1024);
|
||||
eprintln!(
|
||||
"GPU: {} ({} MB free)",
|
||||
info.name,
|
||||
info.free_memory / 1024 / 1024
|
||||
);
|
||||
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
let model_type = config.model_type.as_deref().unwrap_or("unknown");
|
||||
eprintln!(
|
||||
"Model: {model_type}, layers={}, hidden={}, heads={}/{} kv, vocab={}",
|
||||
config.num_layers(), config.hidden(), config.num_heads(),
|
||||
config.num_kv_heads(), config.vocab_size
|
||||
config.num_layers(),
|
||||
config.hidden(),
|
||||
config.num_heads(),
|
||||
config.num_kv_heads(),
|
||||
config.vocab_size
|
||||
);
|
||||
|
||||
eprintln!("Loading weights...");
|
||||
@@ -37,7 +44,11 @@ fn main() {
|
||||
|
||||
let is_qwen3 = model_type.contains("qwen");
|
||||
let is_gpt_oss = model_type.contains("gpt_oss");
|
||||
let dtype = if is_qwen3 || is_gpt_oss { DType::BF16 } else { DType::F32 };
|
||||
let dtype = if is_qwen3 || is_gpt_oss {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
|
||||
// Build model
|
||||
enum Model {
|
||||
@@ -60,10 +71,16 @@ fn main() {
|
||||
print!("xserv> ");
|
||||
io::stdout().flush().unwrap();
|
||||
let mut input = String::new();
|
||||
if io::stdin().read_line(&mut input).unwrap() == 0 { break; }
|
||||
if io::stdin().read_line(&mut input).unwrap() == 0 {
|
||||
break;
|
||||
}
|
||||
let input = input.trim();
|
||||
if input.is_empty() { continue; }
|
||||
if input == "quit" || input == "exit" { break; }
|
||||
if input.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if input == "quit" || input == "exit" {
|
||||
break;
|
||||
}
|
||||
|
||||
let token_ids = tokenizer.encode(input);
|
||||
|
||||
@@ -73,12 +90,21 @@ fn main() {
|
||||
let max_blocks_per_seq = (max_seq + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
let total_blocks = max_blocks_per_seq + 64;
|
||||
let mut paged_cache = PagedKVCache::new(
|
||||
&config, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, 0,
|
||||
&config,
|
||||
total_blocks,
|
||||
0,
|
||||
4,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
0,
|
||||
);
|
||||
let slot = 0;
|
||||
paged_cache.register_sequence(slot).expect("register slot");
|
||||
|
||||
let model = match &model { Model::GptOss(m) => m, _ => unreachable!() };
|
||||
let model = match &model {
|
||||
Model::GptOss(m) => m,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let logits = model.forward_prefill_paged(&token_ids, slot, &mut paged_cache);
|
||||
let mut next = sample_greedy_last(&logits);
|
||||
|
||||
@@ -90,20 +116,28 @@ fn main() {
|
||||
print!("{text}");
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
if tokenizer.eos_token_id() == Some(next) { break; }
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
|
||||
let pos = paged_cache.seq_len(slot);
|
||||
let logits = model.forward_decode_paged(
|
||||
&[next], &[pos], &[slot], &mut paged_cache,
|
||||
);
|
||||
let logits = model.forward_decode_paged(&[next], &[pos], &[slot], &mut paged_cache);
|
||||
next = sample_greedy_last(&logits);
|
||||
}
|
||||
println!();
|
||||
paged_cache.free_sequence(slot);
|
||||
} else {
|
||||
let kv_heads = if is_qwen3 { config.num_kv_heads() } else { config.num_heads() };
|
||||
let kv_heads = if is_qwen3 {
|
||||
config.num_kv_heads()
|
||||
} else {
|
||||
config.num_heads()
|
||||
};
|
||||
let mut cache = KVCache::new(
|
||||
config.num_layers(), kv_heads, config.head_dim(), dtype, Device::Cuda(0),
|
||||
config.num_layers(),
|
||||
kv_heads,
|
||||
config.head_dim(),
|
||||
dtype,
|
||||
Device::Cuda(0),
|
||||
);
|
||||
|
||||
let logits = match &model {
|
||||
@@ -125,7 +159,9 @@ fn main() {
|
||||
print!("{text}");
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
if tokenizer.eos_token_id() == Some(next) { break; }
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
|
||||
let logits = match &model {
|
||||
Model::GPT2(m) => m.forward_with_cache(&[next], &mut cache),
|
||||
@@ -151,7 +187,9 @@ fn sample_greedy_last(logits: &xserv_tensor::Tensor) -> u32 {
|
||||
let seq_len = logits.shape()[0];
|
||||
let data = logits_cpu.as_slice::<bf16>();
|
||||
let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
|
||||
last.iter().enumerate()
|
||||
last.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(i, _)| i as u32).unwrap()
|
||||
.map(|(i, _)| i as u32)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
@@ -88,23 +88,33 @@ impl ModelConfig {
|
||||
}
|
||||
|
||||
pub fn hidden(&self) -> usize {
|
||||
self.hidden_size.or(self.n_embd).expect("hidden_size or n_embd required")
|
||||
self.hidden_size
|
||||
.or(self.n_embd)
|
||||
.expect("hidden_size or n_embd required")
|
||||
}
|
||||
|
||||
pub fn num_heads(&self) -> usize {
|
||||
self.num_attention_heads.or(self.n_head).expect("num_attention_heads or n_head required")
|
||||
self.num_attention_heads
|
||||
.or(self.n_head)
|
||||
.expect("num_attention_heads or n_head required")
|
||||
}
|
||||
|
||||
pub fn num_layers(&self) -> usize {
|
||||
self.num_hidden_layers.or(self.n_layer).expect("num_hidden_layers or n_layer required")
|
||||
self.num_hidden_layers
|
||||
.or(self.n_layer)
|
||||
.expect("num_hidden_layers or n_layer required")
|
||||
}
|
||||
|
||||
pub fn max_seq_len(&self) -> usize {
|
||||
self.max_position_embeddings.or(self.n_positions).unwrap_or(2048)
|
||||
self.max_position_embeddings
|
||||
.or(self.n_positions)
|
||||
.unwrap_or(2048)
|
||||
}
|
||||
|
||||
pub fn ffn_hidden(&self) -> usize {
|
||||
self.intermediate_size.or(self.n_inner).unwrap_or(self.hidden() * 4)
|
||||
self.intermediate_size
|
||||
.or(self.n_inner)
|
||||
.unwrap_or(self.hidden() * 4)
|
||||
}
|
||||
|
||||
pub fn num_kv_heads(&self) -> usize {
|
||||
@@ -112,7 +122,8 @@ impl ModelConfig {
|
||||
}
|
||||
|
||||
pub fn head_dim(&self) -> usize {
|
||||
self.explicit_head_dim.unwrap_or_else(|| self.hidden() / self.num_heads())
|
||||
self.explicit_head_dim
|
||||
.unwrap_or_else(|| self.hidden() / self.num_heads())
|
||||
}
|
||||
|
||||
pub fn ln_eps(&self) -> f32 {
|
||||
|
||||
@@ -199,127 +199,296 @@ impl DecodeGraphState {
|
||||
let cublas = cublas_handle();
|
||||
|
||||
// Set cuBLAS to use our stream
|
||||
unsafe { dispatch::set_cublas_stream(cublas, s); }
|
||||
unsafe {
|
||||
dispatch::set_cublas_stream(cublas, s);
|
||||
}
|
||||
|
||||
for (l, lw) in layers.iter().enumerate() {
|
||||
// === Pre-attention graph ===
|
||||
self.pre_attn_graphs[l].begin_capture(&self.stream).expect("begin pre-attn capture");
|
||||
self.pre_attn_graphs[l]
|
||||
.begin_capture(&self.stream)
|
||||
.expect("begin pre-attn capture");
|
||||
unsafe {
|
||||
// RMSNorm
|
||||
dispatch::rmsnorm_bf16(
|
||||
self.buffers.x.as_ptr() as _, lw.input_norm, self.buffers.normed.as_mut_ptr() as _,
|
||||
1, h, eps, s,
|
||||
self.buffers.x.as_ptr() as _,
|
||||
lw.input_norm,
|
||||
self.buffers.normed.as_mut_ptr() as _,
|
||||
1,
|
||||
h,
|
||||
eps,
|
||||
s,
|
||||
);
|
||||
|
||||
// Q projection (GEMV)
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed.as_ptr() as _, lw.q_proj_wt, self.buffers.q_proj.as_mut_ptr() as _,
|
||||
self.buffers.normed.as_ptr() as _,
|
||||
lw.q_proj_wt,
|
||||
self.buffers.q_proj.as_mut_ptr() as _,
|
||||
self.buffers.fp32_q.as_mut_ptr() as _,
|
||||
h, nh * hd, s,
|
||||
h,
|
||||
nh * hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// K projection (GEMV)
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed.as_ptr() as _, lw.k_proj_wt, self.buffers.k_proj.as_mut_ptr() as _,
|
||||
self.buffers.normed.as_ptr() as _,
|
||||
lw.k_proj_wt,
|
||||
self.buffers.k_proj.as_mut_ptr() as _,
|
||||
self.buffers.fp32_kv.as_mut_ptr() as _,
|
||||
h, nkv * hd, s,
|
||||
h,
|
||||
nkv * hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// V projection (GEMV)
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed.as_ptr() as _, lw.v_proj_wt, self.buffers.v_proj.as_mut_ptr() as _,
|
||||
self.buffers.normed.as_ptr() as _,
|
||||
lw.v_proj_wt,
|
||||
self.buffers.v_proj.as_mut_ptr() as _,
|
||||
self.buffers.fp32_kv.as_mut_ptr() as _,
|
||||
h, nkv * hd, s,
|
||||
h,
|
||||
nkv * hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// Reshape heads: [1, H*D] -> [1, H, 1, D]
|
||||
dispatch::reshape_heads_bf16(self.buffers.q_proj.as_ptr() as _, self.buffers.q_reshaped.as_mut_ptr() as _, 1, nh, hd, s);
|
||||
dispatch::reshape_heads_bf16(self.buffers.k_proj.as_ptr() as _, self.buffers.k_reshaped.as_mut_ptr() as _, 1, nkv, hd, s);
|
||||
dispatch::reshape_heads_bf16(self.buffers.v_proj.as_ptr() as _, self.buffers.v_reshaped.as_mut_ptr() as _, 1, nkv, hd, s);
|
||||
dispatch::reshape_heads_bf16(
|
||||
self.buffers.q_proj.as_ptr() as _,
|
||||
self.buffers.q_reshaped.as_mut_ptr() as _,
|
||||
1,
|
||||
nh,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
dispatch::reshape_heads_bf16(
|
||||
self.buffers.k_proj.as_ptr() as _,
|
||||
self.buffers.k_reshaped.as_mut_ptr() as _,
|
||||
1,
|
||||
nkv,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
dispatch::reshape_heads_bf16(
|
||||
self.buffers.v_proj.as_ptr() as _,
|
||||
self.buffers.v_reshaped.as_mut_ptr() as _,
|
||||
1,
|
||||
nkv,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// QK norm (head-level rmsnorm: treat [1,H,1,D] as [H, D])
|
||||
dispatch::rmsnorm_bf16(self.buffers.q_reshaped.as_ptr() as _, lw.q_norm, self.buffers.q_normed.as_mut_ptr() as _, nh, hd, eps, s);
|
||||
dispatch::rmsnorm_bf16(self.buffers.k_reshaped.as_ptr() as _, lw.k_norm, self.buffers.k_normed.as_mut_ptr() as _, nkv, hd, eps, s);
|
||||
dispatch::rmsnorm_bf16(
|
||||
self.buffers.q_reshaped.as_ptr() as _,
|
||||
lw.q_norm,
|
||||
self.buffers.q_normed.as_mut_ptr() as _,
|
||||
nh,
|
||||
hd,
|
||||
eps,
|
||||
s,
|
||||
);
|
||||
dispatch::rmsnorm_bf16(
|
||||
self.buffers.k_reshaped.as_ptr() as _,
|
||||
lw.k_norm,
|
||||
self.buffers.k_normed.as_mut_ptr() as _,
|
||||
nkv,
|
||||
hd,
|
||||
eps,
|
||||
s,
|
||||
);
|
||||
|
||||
// Transpose for RoPE: [1,H,1,D] -> [1,H,D]
|
||||
dispatch::transpose_hsd_to_shd_bf16(self.buffers.q_normed.as_ptr() as _, self.buffers.q_rope.as_mut_ptr() as _, 1, nh, hd, s);
|
||||
dispatch::transpose_hsd_to_shd_bf16(self.buffers.k_normed.as_ptr() as _, self.buffers.k_rope.as_mut_ptr() as _, 1, nkv, hd, s);
|
||||
dispatch::transpose_hsd_to_shd_bf16(
|
||||
self.buffers.q_normed.as_ptr() as _,
|
||||
self.buffers.q_rope.as_mut_ptr() as _,
|
||||
1,
|
||||
nh,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
dispatch::transpose_hsd_to_shd_bf16(
|
||||
self.buffers.k_normed.as_ptr() as _,
|
||||
self.buffers.k_rope.as_mut_ptr() as _,
|
||||
1,
|
||||
nkv,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// RoPE (in-place, reads position_gpu)
|
||||
dispatch::rope_bf16(self.buffers.q_rope.as_mut_ptr() as _, rope_cos, rope_sin, self.buffers.position_gpu.as_ptr() as _, 1, nh, hd, s);
|
||||
dispatch::rope_bf16(self.buffers.k_rope.as_mut_ptr() as _, rope_cos, rope_sin, self.buffers.position_gpu.as_ptr() as _, 1, nkv, hd, s);
|
||||
dispatch::rope_bf16(
|
||||
self.buffers.q_rope.as_mut_ptr() as _,
|
||||
rope_cos,
|
||||
rope_sin,
|
||||
self.buffers.position_gpu.as_ptr() as _,
|
||||
1,
|
||||
nh,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
dispatch::rope_bf16(
|
||||
self.buffers.k_rope.as_mut_ptr() as _,
|
||||
rope_cos,
|
||||
rope_sin,
|
||||
self.buffers.position_gpu.as_ptr() as _,
|
||||
1,
|
||||
nkv,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// Transpose back: [1,H,D] -> [1,H,1,D]
|
||||
dispatch::transpose_shd_to_hsd_bf16(self.buffers.q_rope.as_ptr() as _, self.buffers.q_final.as_mut_ptr() as _, 1, nh, hd, s);
|
||||
dispatch::transpose_shd_to_hsd_bf16(self.buffers.k_rope.as_ptr() as _, self.buffers.k_final.as_mut_ptr() as _, 1, nkv, hd, s);
|
||||
dispatch::transpose_shd_to_hsd_bf16(
|
||||
self.buffers.q_rope.as_ptr() as _,
|
||||
self.buffers.q_final.as_mut_ptr() as _,
|
||||
1,
|
||||
nh,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
dispatch::transpose_shd_to_hsd_bf16(
|
||||
self.buffers.k_rope.as_ptr() as _,
|
||||
self.buffers.k_final.as_mut_ptr() as _,
|
||||
1,
|
||||
nkv,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
}
|
||||
self.pre_attn_graphs[l].end_capture(&self.stream).expect("end pre-attn capture");
|
||||
self.pre_attn_graphs[l]
|
||||
.end_capture(&self.stream)
|
||||
.expect("end pre-attn capture");
|
||||
|
||||
// === Post-attention graph ===
|
||||
self.post_attn_graphs[l].begin_capture(&self.stream).expect("begin post-attn capture");
|
||||
self.post_attn_graphs[l]
|
||||
.begin_capture(&self.stream)
|
||||
.expect("begin post-attn capture");
|
||||
unsafe {
|
||||
// Merge heads: [1,H,1,D] -> [1, hidden]
|
||||
// attn_out is written by ungraphed attention
|
||||
dispatch::merge_heads_bf16(self.buffers.attn_out.as_ptr() as _, self.buffers.attn_merged.as_mut_ptr() as _, 1, nh, hd, s);
|
||||
dispatch::merge_heads_bf16(
|
||||
self.buffers.attn_out.as_ptr() as _,
|
||||
self.buffers.attn_merged.as_mut_ptr() as _,
|
||||
1,
|
||||
nh,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// O projection
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.attn_merged.as_ptr() as _, lw.o_proj_wt, self.buffers.o_proj.as_mut_ptr() as _,
|
||||
self.buffers.attn_merged.as_ptr() as _,
|
||||
lw.o_proj_wt,
|
||||
self.buffers.o_proj.as_mut_ptr() as _,
|
||||
self.buffers.fp32_hidden.as_mut_ptr() as _,
|
||||
nh * hd, h, s,
|
||||
nh * hd,
|
||||
h,
|
||||
s,
|
||||
);
|
||||
|
||||
// Fused Add+RMSNorm: normed2 = rmsnorm(o_proj + x), sum_out = o_proj + x
|
||||
dispatch::add_rmsnorm_bf16(
|
||||
self.buffers.o_proj.as_ptr() as _, self.buffers.x.as_ptr() as _, lw.post_norm,
|
||||
self.buffers.normed2.as_mut_ptr() as _, self.buffers.sum_out.as_mut_ptr() as _,
|
||||
1, h, eps, s,
|
||||
self.buffers.o_proj.as_ptr() as _,
|
||||
self.buffers.x.as_ptr() as _,
|
||||
lw.post_norm,
|
||||
self.buffers.normed2.as_mut_ptr() as _,
|
||||
self.buffers.sum_out.as_mut_ptr() as _,
|
||||
1,
|
||||
h,
|
||||
eps,
|
||||
s,
|
||||
);
|
||||
|
||||
// Gate projection
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed2.as_ptr() as _, lw.gate_proj_wt, self.buffers.gate.as_mut_ptr() as _,
|
||||
self.buffers.normed2.as_ptr() as _,
|
||||
lw.gate_proj_wt,
|
||||
self.buffers.gate.as_mut_ptr() as _,
|
||||
self.buffers.fp32_intermediate.as_mut_ptr() as _,
|
||||
h, inter, s,
|
||||
h,
|
||||
inter,
|
||||
s,
|
||||
);
|
||||
|
||||
// Up projection
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed2.as_ptr() as _, lw.up_proj_wt, self.buffers.up.as_mut_ptr() as _,
|
||||
self.buffers.normed2.as_ptr() as _,
|
||||
lw.up_proj_wt,
|
||||
self.buffers.up.as_mut_ptr() as _,
|
||||
self.buffers.fp32_intermediate.as_mut_ptr() as _,
|
||||
h, inter, s,
|
||||
h,
|
||||
inter,
|
||||
s,
|
||||
);
|
||||
|
||||
// Fused SiLU x Mul
|
||||
dispatch::silu_mul_bf16(self.buffers.gate.as_ptr() as _, self.buffers.up.as_ptr() as _, self.buffers.silu_out.as_mut_ptr() as _, inter, s);
|
||||
dispatch::silu_mul_bf16(
|
||||
self.buffers.gate.as_ptr() as _,
|
||||
self.buffers.up.as_ptr() as _,
|
||||
self.buffers.silu_out.as_mut_ptr() as _,
|
||||
inter,
|
||||
s,
|
||||
);
|
||||
|
||||
// Down projection
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.silu_out.as_ptr() as _, lw.down_proj_wt, self.buffers.down.as_mut_ptr() as _,
|
||||
self.buffers.silu_out.as_ptr() as _,
|
||||
lw.down_proj_wt,
|
||||
self.buffers.down.as_mut_ptr() as _,
|
||||
self.buffers.fp32_hidden.as_mut_ptr() as _,
|
||||
inter, h, s,
|
||||
inter,
|
||||
h,
|
||||
s,
|
||||
);
|
||||
|
||||
// x = sum_out + down (residual connection for next layer)
|
||||
dispatch::add_bf16(self.buffers.sum_out.as_ptr() as _, self.buffers.down.as_ptr() as _, self.buffers.x.as_mut_ptr() as _, h, s);
|
||||
dispatch::add_bf16(
|
||||
self.buffers.sum_out.as_ptr() as _,
|
||||
self.buffers.down.as_ptr() as _,
|
||||
self.buffers.x.as_mut_ptr() as _,
|
||||
h,
|
||||
s,
|
||||
);
|
||||
}
|
||||
self.post_attn_graphs[l].end_capture(&self.stream).expect("end post-attn capture");
|
||||
self.post_attn_graphs[l]
|
||||
.end_capture(&self.stream)
|
||||
.expect("end post-attn capture");
|
||||
}
|
||||
|
||||
// === Final graph: norm + lm_head ===
|
||||
self.final_graph.begin_capture(&self.stream).expect("begin final capture");
|
||||
self.final_graph
|
||||
.begin_capture(&self.stream)
|
||||
.expect("begin final capture");
|
||||
unsafe {
|
||||
dispatch::rmsnorm_bf16(self.buffers.x.as_ptr() as _, norm_weight, self.buffers.normed.as_mut_ptr() as _, 1, h, eps, s);
|
||||
dispatch::rmsnorm_bf16(
|
||||
self.buffers.x.as_ptr() as _,
|
||||
norm_weight,
|
||||
self.buffers.normed.as_mut_ptr() as _,
|
||||
1,
|
||||
h,
|
||||
eps,
|
||||
s,
|
||||
);
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed.as_ptr() as _, lm_head_wt, self.buffers.logits.as_mut_ptr() as _,
|
||||
self.buffers.normed.as_ptr() as _,
|
||||
lm_head_wt,
|
||||
self.buffers.logits.as_mut_ptr() as _,
|
||||
self.buffers.fp32_vocab.as_mut_ptr() as _,
|
||||
h, vocab, s,
|
||||
h,
|
||||
vocab,
|
||||
s,
|
||||
);
|
||||
}
|
||||
self.final_graph.end_capture(&self.stream).expect("end final capture");
|
||||
self.final_graph
|
||||
.end_capture(&self.stream)
|
||||
.expect("end final capture");
|
||||
|
||||
// Reset cuBLAS back to null stream
|
||||
unsafe { dispatch::set_cublas_stream(cublas, std::ptr::null_mut()); }
|
||||
unsafe {
|
||||
dispatch::set_cublas_stream(cublas, std::ptr::null_mut());
|
||||
}
|
||||
|
||||
self.captured = true;
|
||||
}
|
||||
@@ -343,8 +512,14 @@ impl DecodeGraphState {
|
||||
let es = 2usize; // BF16
|
||||
|
||||
// Upload token ID and position to fixed GPU buffers
|
||||
self.buffers.token_id_gpu.copy_from_host(&token_id.to_le_bytes()).unwrap();
|
||||
self.buffers.position_gpu.copy_from_host(&position.to_le_bytes()).unwrap();
|
||||
self.buffers
|
||||
.token_id_gpu
|
||||
.copy_from_host(&token_id.to_le_bytes())
|
||||
.unwrap();
|
||||
self.buffers
|
||||
.position_gpu
|
||||
.copy_from_host(&position.to_le_bytes())
|
||||
.unwrap();
|
||||
|
||||
// Embedding (outside graph since token_id changes each step)
|
||||
unsafe {
|
||||
@@ -352,13 +527,18 @@ impl DecodeGraphState {
|
||||
embed_table,
|
||||
self.buffers.token_id_gpu.as_ptr() as _,
|
||||
self.buffers.x.as_mut_ptr() as _,
|
||||
1, hidden_size, vocab_size, s,
|
||||
1,
|
||||
hidden_size,
|
||||
vocab_size,
|
||||
s,
|
||||
);
|
||||
}
|
||||
|
||||
for l in 0..self.num_layers {
|
||||
// Pre-attention graph (norm + QKV + reshape + QK-norm + RoPE)
|
||||
self.pre_attn_graphs[l].launch(&self.stream).expect("launch pre-attn graph");
|
||||
self.pre_attn_graphs[l]
|
||||
.launch(&self.stream)
|
||||
.expect("launch pre-attn graph");
|
||||
|
||||
// Ungraphed: KV cache append
|
||||
// k_final shape: [1, num_kv_heads, 1, head_dim] (after RoPE pipeline)
|
||||
@@ -402,9 +582,13 @@ impl DecodeGraphState {
|
||||
k_full.data_ptr() as _,
|
||||
v_full.data_ptr() as _,
|
||||
self.buffers.attn_out.as_mut_ptr() as _,
|
||||
1, nh as i32, nkv as i32,
|
||||
kv_len, hd as i32,
|
||||
scale, s,
|
||||
1,
|
||||
nh as i32,
|
||||
nkv as i32,
|
||||
kv_len,
|
||||
hd as i32,
|
||||
scale,
|
||||
s,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -412,11 +596,15 @@ impl DecodeGraphState {
|
||||
self.stream.synchronize().expect("sync before post-attn");
|
||||
|
||||
// Post-attention graph (merge + O-proj + add_rmsnorm + FFN + residual)
|
||||
self.post_attn_graphs[l].launch(&self.stream).expect("launch post-attn graph");
|
||||
self.post_attn_graphs[l]
|
||||
.launch(&self.stream)
|
||||
.expect("launch post-attn graph");
|
||||
}
|
||||
|
||||
// Final graph (norm + lm_head)
|
||||
self.final_graph.launch(&self.stream).expect("launch final graph");
|
||||
self.final_graph
|
||||
.launch(&self.stream)
|
||||
.expect("launch final graph");
|
||||
|
||||
// Sync to ensure logits are ready
|
||||
self.stream.synchronize().expect("sync after decode");
|
||||
|
||||
@@ -42,7 +42,13 @@ pub struct KVCache {
|
||||
}
|
||||
|
||||
impl KVCache {
|
||||
pub fn new(num_layers: usize, num_heads: usize, head_dim: usize, dtype: DType, device: Device) -> Self {
|
||||
pub fn new(
|
||||
num_layers: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
dtype: DType,
|
||||
device: Device,
|
||||
) -> Self {
|
||||
Self {
|
||||
k: (0..num_layers).map(|_| vec![vec![]; num_heads]).collect(),
|
||||
v: (0..num_layers).map(|_| vec![vec![]; num_heads]).collect(),
|
||||
@@ -55,10 +61,18 @@ impl KVCache {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn seq_len(&self) -> usize { self.len }
|
||||
pub fn seq_len(&self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
/// Append from a CPU tensor with shape [1, H, new_tokens, D].
|
||||
pub fn append_kv_tensor(&mut self, layer: usize, k_cpu: &Tensor, v_cpu: &Tensor, new_tokens: usize) {
|
||||
pub fn append_kv_tensor(
|
||||
&mut self,
|
||||
layer: usize,
|
||||
k_cpu: &Tensor,
|
||||
v_cpu: &Tensor,
|
||||
new_tokens: usize,
|
||||
) {
|
||||
let hd = self.head_dim;
|
||||
let es = self.elem_size;
|
||||
let k_bytes = k_cpu.storage().as_cpu_bytes();
|
||||
@@ -118,7 +132,8 @@ impl GPT2 {
|
||||
pub fn from_weights(config: ModelConfig, mut w: HashMap<String, Tensor>) -> Self {
|
||||
crate::init_kernels();
|
||||
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
|
||||
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
w.remove(name)
|
||||
.unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
};
|
||||
|
||||
let wte = take(&mut w, "wte.weight");
|
||||
@@ -147,7 +162,15 @@ impl GPT2 {
|
||||
});
|
||||
}
|
||||
|
||||
Self { config, wte, wpe, layers, ln_f_g, ln_f_b, lm_head }
|
||||
Self {
|
||||
config,
|
||||
wte,
|
||||
wpe,
|
||||
layers,
|
||||
ln_f_g,
|
||||
ln_f_b,
|
||||
lm_head,
|
||||
}
|
||||
}
|
||||
|
||||
/// Full forward pass without KV cache (for testing / correctness comparison).
|
||||
@@ -179,14 +202,22 @@ impl GPT2 {
|
||||
let head_dim = self.config.head_dim();
|
||||
|
||||
let tok_emb = embedding(&self.wte, token_ids);
|
||||
let pos_ids: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
|
||||
let pos_ids: Vec<u32> = (pos_offset..pos_offset + new_tokens)
|
||||
.map(|p| p as u32)
|
||||
.collect();
|
||||
let pos_emb = embedding(&self.wpe, &pos_ids);
|
||||
let mut x = add_tensors(&tok_emb, &pos_emb);
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
x = self.transformer_block(
|
||||
layer, &x, Some((cache, layer_idx)),
|
||||
pos_offset, new_tokens, num_heads, head_dim, hidden,
|
||||
layer,
|
||||
&x,
|
||||
Some((cache, layer_idx)),
|
||||
pos_offset,
|
||||
new_tokens,
|
||||
num_heads,
|
||||
head_dim,
|
||||
hidden,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -238,7 +269,11 @@ impl GPT2 {
|
||||
|
||||
fn linear(x: &Tensor, weight: &Tensor, bias: Option<&Tensor>) -> Tensor {
|
||||
let out = matmul_2d(x, weight);
|
||||
if let Some(b) = bias { add_bias(&out, b) } else { out }
|
||||
if let Some(b) = bias {
|
||||
add_bias(&out, b)
|
||||
} else {
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
@@ -277,7 +312,12 @@ fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) -> (Tensor, Tensor, Tensor) {
|
||||
fn split_qkv(
|
||||
qkv: &Tensor,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
seq_len: usize,
|
||||
) -> (Tensor, Tensor, Tensor) {
|
||||
let hidden = num_heads * head_dim;
|
||||
let qkv_cpu = qkv.to_device(Device::Cpu);
|
||||
let device = qkv.device();
|
||||
@@ -294,14 +334,21 @@ fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) ->
|
||||
for h in 0..num_heads {
|
||||
let src_off = h * head_dim;
|
||||
let dst_off = (h * seq_len + s) * head_dim;
|
||||
q_data[dst_off..dst_off + head_dim].copy_from_slice(&row[src_off..src_off + head_dim]);
|
||||
k_data[dst_off..dst_off + head_dim].copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
|
||||
v_data[dst_off..dst_off + head_dim].copy_from_slice(&row[2 * hidden + src_off..2 * hidden + src_off + head_dim]);
|
||||
q_data[dst_off..dst_off + head_dim]
|
||||
.copy_from_slice(&row[src_off..src_off + head_dim]);
|
||||
k_data[dst_off..dst_off + head_dim]
|
||||
.copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
|
||||
v_data[dst_off..dst_off + head_dim].copy_from_slice(
|
||||
&row[2 * hidden + src_off..2 * hidden + src_off + head_dim],
|
||||
);
|
||||
}
|
||||
}
|
||||
let q = Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let k = Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let v = Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let q =
|
||||
Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let k =
|
||||
Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let v =
|
||||
Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
(q, k, v)
|
||||
}
|
||||
DType::BF16 => {
|
||||
@@ -314,14 +361,21 @@ fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) ->
|
||||
for h in 0..num_heads {
|
||||
let src_off = h * head_dim;
|
||||
let dst_off = (h * seq_len + s) * head_dim;
|
||||
q_data[dst_off..dst_off + head_dim].copy_from_slice(&row[src_off..src_off + head_dim]);
|
||||
k_data[dst_off..dst_off + head_dim].copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
|
||||
v_data[dst_off..dst_off + head_dim].copy_from_slice(&row[2 * hidden + src_off..2 * hidden + src_off + head_dim]);
|
||||
q_data[dst_off..dst_off + head_dim]
|
||||
.copy_from_slice(&row[src_off..src_off + head_dim]);
|
||||
k_data[dst_off..dst_off + head_dim]
|
||||
.copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
|
||||
v_data[dst_off..dst_off + head_dim].copy_from_slice(
|
||||
&row[2 * hidden + src_off..2 * hidden + src_off + head_dim],
|
||||
);
|
||||
}
|
||||
}
|
||||
let q = Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let k = Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let v = Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let q =
|
||||
Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let k =
|
||||
Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let v =
|
||||
Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
(q, k, v)
|
||||
}
|
||||
_ => panic!("unsupported dtype {:?} in split_qkv", dtype),
|
||||
@@ -343,7 +397,8 @@ fn merge_heads(x: &Tensor, seq_len: usize, hidden: usize) -> Tensor {
|
||||
for h in 0..num_heads {
|
||||
let src_off = (h * seq_len + s) * head_dim;
|
||||
let dst_off = s * hidden + h * head_dim;
|
||||
out[dst_off..dst_off + head_dim].copy_from_slice(&src[src_off..src_off + head_dim]);
|
||||
out[dst_off..dst_off + head_dim]
|
||||
.copy_from_slice(&src[src_off..src_off + head_dim]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[seq_len, hidden]).to_device(device)
|
||||
@@ -355,7 +410,8 @@ fn merge_heads(x: &Tensor, seq_len: usize, hidden: usize) -> Tensor {
|
||||
for h in 0..num_heads {
|
||||
let src_off = (h * seq_len + s) * head_dim;
|
||||
let dst_off = s * hidden + h * head_dim;
|
||||
out[dst_off..dst_off + head_dim].copy_from_slice(&src[src_off..src_off + head_dim]);
|
||||
out[dst_off..dst_off + head_dim]
|
||||
.copy_from_slice(&src[src_off..src_off + head_dim]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[seq_len, hidden]).to_device(device)
|
||||
@@ -372,7 +428,8 @@ pub fn sample_greedy(logits: &Tensor) -> u32 {
|
||||
let vocab_size = logits.shape()[1];
|
||||
let seq_len = logits.shape()[0];
|
||||
let last_row = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
|
||||
last_row.iter()
|
||||
last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
||||
.map(|(idx, _)| idx as u32)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use half::bf16;
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use half::bf16;
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{Device, Tensor};
|
||||
|
||||
@@ -79,16 +79,23 @@ impl GptOss {
|
||||
crate::init_kernels();
|
||||
let dev = Device::Cuda(device);
|
||||
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
|
||||
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
w.remove(name)
|
||||
.unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
};
|
||||
let repl = |t: Tensor| -> Tensor { t.to_device(dev) };
|
||||
// column-parallel: shard rows of [out, in], transpose → [in, out/world]
|
||||
let col = |t: Tensor| -> Tensor {
|
||||
shard_rows(&t, rank, world).to_device(dev).transpose(0, 1).contiguous()
|
||||
shard_rows(&t, rank, world)
|
||||
.to_device(dev)
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
};
|
||||
// row-parallel: shard cols of [out, in], transpose → [in/world, out]
|
||||
let row = |t: Tensor| -> Tensor {
|
||||
shard_cols(&t, rank, world).to_device(dev).transpose(0, 1).contiguous()
|
||||
shard_cols(&t, rank, world)
|
||||
.to_device(dev)
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
};
|
||||
// Bias sharding helpers
|
||||
let col_bias = |t: Tensor| -> Tensor { shard_1d(&t, rank, world).to_device(dev) };
|
||||
@@ -97,7 +104,9 @@ impl GptOss {
|
||||
let embed_tokens = repl(take(&mut w, "model.embed_tokens.weight"));
|
||||
let norm = repl(take(&mut w, "model.norm.weight"));
|
||||
let norm_bias = w.remove("model.norm.bias").map(|t| repl(t));
|
||||
let lm_head_t = repl(take(&mut w, "lm_head.weight")).transpose(0, 1).contiguous();
|
||||
let lm_head_t = repl(take(&mut w, "lm_head.weight"))
|
||||
.transpose(0, 1)
|
||||
.contiguous();
|
||||
|
||||
let head_dim = config.head_dim();
|
||||
let rope_theta = config.rope_theta.unwrap_or(150000.0);
|
||||
@@ -176,15 +185,30 @@ impl GptOss {
|
||||
// MXFP4 stores 4-bit weights in an FP8E4M3 byte container (same dtype
|
||||
// as FP8), so distinguish by the scale rank: FP8 scale is 1-D [E],
|
||||
// MXFP4 scale is 3-D [E, N, K/32].
|
||||
let is_mxfp4 = gate_up_scale.as_ref().map(|s| s.ndim() == 3).unwrap_or(false);
|
||||
let is_mxfp4 = gate_up_scale
|
||||
.as_ref()
|
||||
.map(|s| s.ndim() == 3)
|
||||
.unwrap_or(false);
|
||||
let is_fp8 = !is_mxfp4 && gate_up_3d.dtype() == xserv_tensor::DType::FP8E4M3;
|
||||
|
||||
let mut expert_gate_up_mxfp4: Option<(Tensor, Tensor)> = None;
|
||||
let mut expert_down_mxfp4: Option<(Tensor, Tensor)> = None;
|
||||
|
||||
let inter2 = if is_mxfp4 { gate_up_3d.shape()[1] } else { gate_up_3d.shape()[2] }; // 2*inter (N)
|
||||
let hidden = if is_mxfp4 { gate_up_3d.shape()[2] * 2 } else { gate_up_3d.shape()[1] };
|
||||
let inter = if is_mxfp4 { down_3d.shape()[2] * 2 } else { down_3d.shape()[1] };
|
||||
let inter2 = if is_mxfp4 {
|
||||
gate_up_3d.shape()[1]
|
||||
} else {
|
||||
gate_up_3d.shape()[2]
|
||||
}; // 2*inter (N)
|
||||
let hidden = if is_mxfp4 {
|
||||
gate_up_3d.shape()[2] * 2
|
||||
} else {
|
||||
gate_up_3d.shape()[1]
|
||||
};
|
||||
let inter = if is_mxfp4 {
|
||||
down_3d.shape()[2] * 2
|
||||
} else {
|
||||
down_3d.shape()[1]
|
||||
};
|
||||
|
||||
// Slice the rank's range of experts as contiguous 3D tensors on GPU
|
||||
let expert_gate_up_wt;
|
||||
@@ -199,10 +223,38 @@ impl GptOss {
|
||||
// + scales [E, N, K/32]. Slice this rank's experts (raw bytes).
|
||||
let gu_s = gate_up_scale.expect("MXFP4 model missing gate_up_proj_scale");
|
||||
let d_s = down_scale.expect("MXFP4 model missing down_proj_scale");
|
||||
let gu_packed = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, inter2, hidden / 2).to_device(dev);
|
||||
let gu_scl = slice_expert_range_3d_raw(&gu_s, expert_start, local_experts, inter2, hidden / 32).to_device(dev);
|
||||
let dn_packed = slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, hidden, inter / 2).to_device(dev);
|
||||
let dn_scl = slice_expert_range_3d_raw(&d_s, expert_start, local_experts, hidden, inter / 32).to_device(dev);
|
||||
let gu_packed = slice_expert_range_3d_raw(
|
||||
&gate_up_3d,
|
||||
expert_start,
|
||||
local_experts,
|
||||
inter2,
|
||||
hidden / 2,
|
||||
)
|
||||
.to_device(dev);
|
||||
let gu_scl = slice_expert_range_3d_raw(
|
||||
&gu_s,
|
||||
expert_start,
|
||||
local_experts,
|
||||
inter2,
|
||||
hidden / 32,
|
||||
)
|
||||
.to_device(dev);
|
||||
let dn_packed = slice_expert_range_3d_raw(
|
||||
&down_3d,
|
||||
expert_start,
|
||||
local_experts,
|
||||
hidden,
|
||||
inter / 2,
|
||||
)
|
||||
.to_device(dev);
|
||||
let dn_scl = slice_expert_range_3d_raw(
|
||||
&d_s,
|
||||
expert_start,
|
||||
local_experts,
|
||||
hidden,
|
||||
inter / 32,
|
||||
)
|
||||
.to_device(dev);
|
||||
expert_gate_up_mxfp4 = Some((gu_packed, gu_scl));
|
||||
expert_down_mxfp4 = Some((dn_packed, dn_scl));
|
||||
expert_gate_up_fp8 = None;
|
||||
@@ -214,36 +266,65 @@ impl GptOss {
|
||||
} else if is_fp8 {
|
||||
// FP8 W8A8 path: load and TRANSPOSE weights for cuBLASLt (requires transA=T on Blackwell).
|
||||
// Original: [E, K, N] → Transposed: [E, N, K]
|
||||
let gu_sliced = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, hidden, inter2);
|
||||
let dn_sliced = slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, inter, hidden);
|
||||
expert_gate_up_fp8 = Some(transpose_3d_inner_raw(&gu_sliced, local_experts, hidden, inter2).to_device(dev));
|
||||
expert_down_fp8 = Some(transpose_3d_inner_raw(&dn_sliced, local_experts, inter, hidden).to_device(dev));
|
||||
let gu_sliced = slice_expert_range_3d_raw(
|
||||
&gate_up_3d,
|
||||
expert_start,
|
||||
local_experts,
|
||||
hidden,
|
||||
inter2,
|
||||
);
|
||||
let dn_sliced =
|
||||
slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, inter, hidden);
|
||||
expert_gate_up_fp8 = Some(
|
||||
transpose_3d_inner_raw(&gu_sliced, local_experts, hidden, inter2)
|
||||
.to_device(dev),
|
||||
);
|
||||
expert_down_fp8 = Some(
|
||||
transpose_3d_inner_raw(&dn_sliced, local_experts, inter, hidden).to_device(dev),
|
||||
);
|
||||
// Scales: [num_experts] F32 → slice to [local_experts]
|
||||
let gu_s = gate_up_scale.expect("FP8 model missing gate_up_proj_scale");
|
||||
let d_s = down_scale.expect("FP8 model missing down_proj_scale");
|
||||
expert_gate_up_scale_gpu = Some(slice_scale_range(&gu_s, expert_start, local_experts).to_device(dev));
|
||||
expert_down_scale_gpu = Some(slice_scale_range(&d_s, expert_start, local_experts).to_device(dev));
|
||||
expert_gate_up_scale_gpu =
|
||||
Some(slice_scale_range(&gu_s, expert_start, local_experts).to_device(dev));
|
||||
expert_down_scale_gpu =
|
||||
Some(slice_scale_range(&d_s, expert_start, local_experts).to_device(dev));
|
||||
// Dummy BF16 tensors (never read in FP8 path)
|
||||
expert_gate_up_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev);
|
||||
expert_down_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev);
|
||||
} else {
|
||||
// BF16 path: existing behavior
|
||||
expert_gate_up_wt = slice_expert_range_3d(&gate_up_3d, expert_start, local_experts, hidden, inter2).to_device(dev);
|
||||
expert_down_wt = slice_expert_range_3d(&down_3d, expert_start, local_experts, inter, hidden).to_device(dev);
|
||||
expert_gate_up_wt =
|
||||
slice_expert_range_3d(&gate_up_3d, expert_start, local_experts, hidden, inter2)
|
||||
.to_device(dev);
|
||||
expert_down_wt =
|
||||
slice_expert_range_3d(&down_3d, expert_start, local_experts, inter, hidden)
|
||||
.to_device(dev);
|
||||
expert_gate_up_fp8 = None;
|
||||
expert_gate_up_scale_gpu = None;
|
||||
expert_down_fp8 = None;
|
||||
expert_down_scale_gpu = None;
|
||||
}
|
||||
let expert_gate_up_bias = slice_expert_range_2d(&gate_up_bias_2d, expert_start, local_experts, inter2).to_device(dev);
|
||||
let expert_down_bias = slice_expert_range_2d(&down_bias_2d, expert_start, local_experts, hidden).to_device(dev);
|
||||
let expert_gate_up_bias =
|
||||
slice_expert_range_2d(&gate_up_bias_2d, expert_start, local_experts, inter2)
|
||||
.to_device(dev);
|
||||
let expert_down_bias =
|
||||
slice_expert_range_2d(&down_bias_2d, expert_start, local_experts, hidden)
|
||||
.to_device(dev);
|
||||
|
||||
xserv_cuda::allocator::cached_trim();
|
||||
|
||||
let input_norm = repl(take(&mut w, &format!("{p}.input_layernorm.weight")));
|
||||
let input_norm_bias = w.remove(&format!("{p}.input_layernorm.bias")).map(|t| repl(t));
|
||||
let post_norm = repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight")));
|
||||
let post_norm_bias = w.remove(&format!("{p}.post_attention_layernorm.bias")).map(|t| repl(t));
|
||||
let input_norm_bias = w
|
||||
.remove(&format!("{p}.input_layernorm.bias"))
|
||||
.map(|t| repl(t));
|
||||
let post_norm = repl(take(
|
||||
&mut w,
|
||||
&format!("{p}.post_attention_layernorm.weight"),
|
||||
));
|
||||
let post_norm_bias = w
|
||||
.remove(&format!("{p}.post_attention_layernorm.bias"))
|
||||
.map(|t| repl(t));
|
||||
|
||||
layers.push(GptOssBlock {
|
||||
input_norm,
|
||||
@@ -283,17 +364,27 @@ impl GptOss {
|
||||
let local_num_kv_heads = config.num_kv_heads() / world;
|
||||
|
||||
let has_norm_bias = norm_bias.is_some();
|
||||
let is_fp8 = layers.first().map(|l| l.expert_gate_up_fp8.is_some()).unwrap_or(false);
|
||||
let is_mxfp4 = layers.first().map(|l| l.expert_gate_up_mxfp4.is_some()).unwrap_or(false);
|
||||
let is_fp8 = layers
|
||||
.first()
|
||||
.map(|l| l.expert_gate_up_fp8.is_some())
|
||||
.unwrap_or(false);
|
||||
let is_mxfp4 = layers
|
||||
.first()
|
||||
.map(|l| l.expert_gate_up_mxfp4.is_some())
|
||||
.unwrap_or(false);
|
||||
if rank == 0 {
|
||||
if has_norm_bias {
|
||||
eprintln!("gpt-oss: detected LayerNorm bias — using LayerNorm instead of RMSNorm");
|
||||
}
|
||||
if is_fp8 {
|
||||
eprintln!("gpt-oss: FP8 E4M3 quantized expert weights detected (W8A8 cuBLASLt mode)");
|
||||
eprintln!(
|
||||
"gpt-oss: FP8 E4M3 quantized expert weights detected (W8A8 cuBLASLt mode)"
|
||||
);
|
||||
}
|
||||
if is_mxfp4 {
|
||||
eprintln!("gpt-oss: MXFP4 quantized expert weights detected (W4A16 fused-GEMV mode)");
|
||||
eprintln!(
|
||||
"gpt-oss: MXFP4 quantized expert weights detected (W4A16 fused-GEMV mode)"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -341,7 +432,13 @@ impl GptOss {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn add_norm(x: &Tensor, residual: &Tensor, weight: &Tensor, bias: &Option<Tensor>, eps: f32) -> (Tensor, Tensor) {
|
||||
fn add_norm(
|
||||
x: &Tensor,
|
||||
residual: &Tensor,
|
||||
weight: &Tensor,
|
||||
bias: &Option<Tensor>,
|
||||
eps: f32,
|
||||
) -> (Tensor, Tensor) {
|
||||
match bias {
|
||||
Some(b) => {
|
||||
let sum = xserv_kernels::add(x, residual);
|
||||
@@ -439,7 +536,6 @@ impl GptOss {
|
||||
let k_all = add_bias(&matmul_2d(&normed, &layer.k_proj_wt), &layer.k_proj_bias);
|
||||
let v_all = add_bias(&matmul_2d(&normed, &layer.v_proj_wt), &layer.v_proj_bias);
|
||||
|
||||
|
||||
// Reshape for RoPE: [B, H*D] → [B, H, D]
|
||||
let q_3d = q_all.reshape(&[batch, num_heads, head_dim]);
|
||||
let k_3d = k_all.reshape(&[batch, num_kv_heads, head_dim]);
|
||||
@@ -460,9 +556,17 @@ impl GptOss {
|
||||
let sinks_ptr = layer.sinks.data_ptr() as *const c_void;
|
||||
|
||||
let attn_out = paged_decode_attention_sinks(
|
||||
&q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr,
|
||||
&q_4d,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
bt_ptr,
|
||||
cl_ptr,
|
||||
sinks_ptr,
|
||||
batch, num_heads, num_kv_heads, head_dim, max_blocks,
|
||||
batch,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
max_blocks,
|
||||
layer.window_size,
|
||||
);
|
||||
|
||||
@@ -471,9 +575,14 @@ impl GptOss {
|
||||
self.all_reduce(&attn_proj);
|
||||
let attn_proj = add_bias(&attn_proj, &layer.o_proj_bias);
|
||||
|
||||
|
||||
// Residual + post-norm
|
||||
let (normed, x_new) = Self::add_norm(&attn_proj, &residual, &layer.post_norm, &layer.post_norm_bias, eps);
|
||||
let (normed, x_new) = Self::add_norm(
|
||||
&attn_proj,
|
||||
&residual,
|
||||
&layer.post_norm,
|
||||
&layer.post_norm_bias,
|
||||
eps,
|
||||
);
|
||||
|
||||
let residual = x_new;
|
||||
let normed = normed.contiguous();
|
||||
@@ -505,7 +614,9 @@ impl GptOss {
|
||||
paged_cache.advance_seq_len(slot, new_tokens);
|
||||
|
||||
let mut x = embedding(&self.embed_tokens, token_ids);
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
|
||||
.map(|p| p as u32)
|
||||
.collect();
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
@@ -532,14 +643,21 @@ impl GptOss {
|
||||
let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
|
||||
|
||||
// Flash attention with gpt-oss sinks + (per-layer) sliding window.
|
||||
let attn_out = flash_attention_sinks(&q, &k_full, &v_full, &layer.sinks, layer.window_size);
|
||||
let attn_out =
|
||||
flash_attention_sinks(&q, &k_full, &v_full, &layer.sinks, layer.window_size);
|
||||
|
||||
let attn_merged = merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
self.all_reduce(&attn_proj);
|
||||
let attn_proj = add_bias(&attn_proj, &layer.o_proj_bias);
|
||||
|
||||
let (normed, x_new) = Self::add_norm(&attn_proj, &residual, &layer.post_norm, &layer.post_norm_bias, eps);
|
||||
let (normed, x_new) = Self::add_norm(
|
||||
&attn_proj,
|
||||
&residual,
|
||||
&layer.post_norm,
|
||||
&layer.post_norm_bias,
|
||||
eps,
|
||||
);
|
||||
let residual = x_new;
|
||||
|
||||
// MoE MLP
|
||||
@@ -566,15 +684,11 @@ impl GptOss {
|
||||
let expert_start = rank * local_experts;
|
||||
|
||||
// 1. Router: [tokens, hidden] @ [hidden, num_experts] + bias → [tokens, num_experts]
|
||||
let router_logits = add_bias(
|
||||
&matmul_2d(x, &layer.router_wt),
|
||||
&layer.router_bias,
|
||||
);
|
||||
let router_logits = add_bias(&matmul_2d(x, &layer.router_wt), &layer.router_bias);
|
||||
|
||||
// 2. GPU top-k + softmax
|
||||
let (topk_ids, topk_weights) = xserv_kernels::moe::moe_topk_softmax(
|
||||
&router_logits, num_experts, top_k,
|
||||
);
|
||||
let (topk_ids, topk_weights) =
|
||||
xserv_kernels::moe::moe_topk_softmax(&router_logits, num_experts, top_k);
|
||||
|
||||
// Sparse decode path: compute ONLY the routed experts. The dense path
|
||||
// below reads every local expert's weights per forward; the sparse
|
||||
@@ -588,15 +702,31 @@ impl GptOss {
|
||||
let n = packed.shape()[1];
|
||||
let k = packed.shape()[2] * 2;
|
||||
xserv_kernels::moe::moe_sparse_gemv_mxfp4(
|
||||
x, packed, scales, &layer.expert_gate_up_bias, &topk_ids,
|
||||
num_tokens, top_k, n, k, expert_start, local_experts, false,
|
||||
x,
|
||||
packed,
|
||||
scales,
|
||||
&layer.expert_gate_up_bias,
|
||||
&topk_ids,
|
||||
num_tokens,
|
||||
top_k,
|
||||
n,
|
||||
k,
|
||||
expert_start,
|
||||
local_experts,
|
||||
false,
|
||||
)
|
||||
} else {
|
||||
xserv_kernels::moe::moe_sparse_gemv_fp8(
|
||||
x, layer.expert_gate_up_fp8.as_ref().unwrap(),
|
||||
x,
|
||||
layer.expert_gate_up_fp8.as_ref().unwrap(),
|
||||
layer.expert_gate_up_scale.as_ref().unwrap(),
|
||||
&layer.expert_gate_up_bias, &topk_ids,
|
||||
num_tokens, top_k, expert_start, local_experts, false,
|
||||
&layer.expert_gate_up_bias,
|
||||
&topk_ids,
|
||||
num_tokens,
|
||||
top_k,
|
||||
expert_start,
|
||||
local_experts,
|
||||
false,
|
||||
)
|
||||
};
|
||||
|
||||
@@ -611,20 +741,40 @@ impl GptOss {
|
||||
let n = packed.shape()[1];
|
||||
let k = packed.shape()[2] * 2;
|
||||
xserv_kernels::moe::moe_sparse_gemv_mxfp4(
|
||||
&activated, packed, scales, &layer.expert_down_bias, &topk_ids,
|
||||
num_tokens, top_k, n, k, expert_start, local_experts, true,
|
||||
&activated,
|
||||
packed,
|
||||
scales,
|
||||
&layer.expert_down_bias,
|
||||
&topk_ids,
|
||||
num_tokens,
|
||||
top_k,
|
||||
n,
|
||||
k,
|
||||
expert_start,
|
||||
local_experts,
|
||||
true,
|
||||
)
|
||||
} else {
|
||||
xserv_kernels::moe::moe_sparse_gemv_fp8(
|
||||
&activated, layer.expert_down_fp8.as_ref().unwrap(),
|
||||
&activated,
|
||||
layer.expert_down_fp8.as_ref().unwrap(),
|
||||
layer.expert_down_scale.as_ref().unwrap(),
|
||||
&layer.expert_down_bias, &topk_ids,
|
||||
num_tokens, top_k, expert_start, local_experts, true,
|
||||
&layer.expert_down_bias,
|
||||
&topk_ids,
|
||||
num_tokens,
|
||||
top_k,
|
||||
expert_start,
|
||||
local_experts,
|
||||
true,
|
||||
)
|
||||
};
|
||||
|
||||
let moe_out = xserv_kernels::moe::moe_weighted_sum_sparse(
|
||||
&down, &topk_ids, &topk_weights, expert_start, local_experts,
|
||||
&down,
|
||||
&topk_ids,
|
||||
&topk_weights,
|
||||
expert_start,
|
||||
local_experts,
|
||||
);
|
||||
self.all_reduce(&moe_out);
|
||||
return moe_out;
|
||||
@@ -644,14 +794,24 @@ impl GptOss {
|
||||
xserv_kernels::quantization::batched_gemv_mxfp4(&x2, packed, scales, n, k)
|
||||
.reshape(&[local_experts, 1, n])
|
||||
} else {
|
||||
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(packed, scales, local_experts, n, k);
|
||||
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(
|
||||
packed,
|
||||
scales,
|
||||
local_experts,
|
||||
n,
|
||||
k,
|
||||
);
|
||||
xserv_kernels::moe::batched_gemm_strided(&x_rep, &w_bf16)
|
||||
}
|
||||
} else if let Some(ref wt_fp8_t) = layer.expert_gate_up_fp8 {
|
||||
// W8A8: quantize activations with per-expert scalar scale, use cuBLASLt FP8 GEMM
|
||||
let (x_fp8, x_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&x_rep);
|
||||
let (x_fp8, x_scales) =
|
||||
xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&x_rep);
|
||||
xserv_kernels::quantization::batched_gemm_fp8(
|
||||
&x_fp8, &x_scales, wt_fp8_t, layer.expert_gate_up_scale.as_ref().unwrap(),
|
||||
&x_fp8,
|
||||
&x_scales,
|
||||
wt_fp8_t,
|
||||
layer.expert_gate_up_scale.as_ref().unwrap(),
|
||||
)
|
||||
} else {
|
||||
xserv_kernels::moe::batched_gemm_strided(&x_rep, &layer.expert_gate_up_wt)
|
||||
@@ -677,14 +837,24 @@ impl GptOss {
|
||||
xserv_kernels::quantization::batched_gemv_mxfp4(&a2, packed, scales, n, k)
|
||||
.reshape(&[local_experts, 1, n])
|
||||
} else {
|
||||
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(packed, scales, local_experts, n, k);
|
||||
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(
|
||||
packed,
|
||||
scales,
|
||||
local_experts,
|
||||
n,
|
||||
k,
|
||||
);
|
||||
xserv_kernels::moe::batched_gemm_strided(&activated, &w_bf16)
|
||||
}
|
||||
} else if let Some(ref wt_fp8) = layer.expert_down_fp8 {
|
||||
// W8A8: quantize post-GLU activations to FP8, use cuBLASLt FP8 GEMM
|
||||
let (act_fp8, act_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&activated);
|
||||
let (act_fp8, act_scales) =
|
||||
xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&activated);
|
||||
xserv_kernels::quantization::batched_gemm_fp8(
|
||||
&act_fp8, &act_scales, wt_fp8, layer.expert_down_scale.as_ref().unwrap(),
|
||||
&act_fp8,
|
||||
&act_scales,
|
||||
wt_fp8,
|
||||
layer.expert_down_scale.as_ref().unwrap(),
|
||||
)
|
||||
} else {
|
||||
xserv_kernels::moe::batched_gemm_strided(&activated, &layer.expert_down_wt)
|
||||
@@ -695,8 +865,12 @@ impl GptOss {
|
||||
|
||||
// 9. Weighted sum across experts → [tokens, hidden]
|
||||
let moe_out = xserv_kernels::moe::moe_weighted_sum(
|
||||
&down, &topk_ids, &topk_weights,
|
||||
expert_start, local_experts, top_k,
|
||||
&down,
|
||||
&topk_ids,
|
||||
&topk_weights,
|
||||
expert_start,
|
||||
local_experts,
|
||||
top_k,
|
||||
);
|
||||
|
||||
self.all_reduce(&moe_out);
|
||||
@@ -708,9 +882,7 @@ impl GptOss {
|
||||
|
||||
/// Upload a u32 slice to a pooled GPU buffer (synchronous H2D).
|
||||
fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer {
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(vals.as_ptr() as *const u8, vals.len() * 4)
|
||||
};
|
||||
let bytes = unsafe { std::slice::from_raw_parts(vals.as_ptr() as *const u8, vals.len() * 4) };
|
||||
let mut buf = xserv_cuda::allocator::cached_alloc(bytes.len()).expect("alloc u32 upload");
|
||||
buf.copy_from_host(bytes).unwrap();
|
||||
buf
|
||||
@@ -737,11 +909,16 @@ fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor {
|
||||
}
|
||||
|
||||
fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor {
|
||||
if world == 1 { return t.clone(); }
|
||||
if world == 1 {
|
||||
return t.clone();
|
||||
}
|
||||
let shape = t.shape();
|
||||
assert_eq!(shape.len(), 2);
|
||||
let (rows, cols) = (shape[0], shape[1]);
|
||||
assert!(rows % world == 0, "rows {rows} not divisible by world {world}");
|
||||
assert!(
|
||||
rows % world == 0,
|
||||
"rows {rows} not divisible by world {world}"
|
||||
);
|
||||
let local = rows / world;
|
||||
let host = t.to_device(Device::Cpu);
|
||||
let data = host.as_slice::<bf16>();
|
||||
@@ -751,11 +928,16 @@ fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor {
|
||||
}
|
||||
|
||||
fn shard_cols(t: &Tensor, rank: usize, world: usize) -> Tensor {
|
||||
if world == 1 { return t.clone(); }
|
||||
if world == 1 {
|
||||
return t.clone();
|
||||
}
|
||||
let shape = t.shape();
|
||||
assert_eq!(shape.len(), 2);
|
||||
let (rows, cols) = (shape[0], shape[1]);
|
||||
assert!(cols % world == 0, "cols {cols} not divisible by world {world}");
|
||||
assert!(
|
||||
cols % world == 0,
|
||||
"cols {cols} not divisible by world {world}"
|
||||
);
|
||||
let local = cols / world;
|
||||
let c0 = rank * local;
|
||||
let host = t.to_device(Device::Cpu);
|
||||
@@ -769,11 +951,16 @@ fn shard_cols(t: &Tensor, rank: usize, world: usize) -> Tensor {
|
||||
}
|
||||
|
||||
fn shard_1d(t: &Tensor, rank: usize, world: usize) -> Tensor {
|
||||
if world == 1 { return t.clone(); }
|
||||
if world == 1 {
|
||||
return t.clone();
|
||||
}
|
||||
let shape = t.shape();
|
||||
assert_eq!(shape.len(), 1);
|
||||
let total = shape[0];
|
||||
assert!(total % world == 0, "dim {total} not divisible by world {world}");
|
||||
assert!(
|
||||
total % world == 0,
|
||||
"dim {total} not divisible by world {world}"
|
||||
);
|
||||
let local = total / world;
|
||||
let host = t.to_device(Device::Cpu);
|
||||
let data = host.as_slice::<bf16>();
|
||||
@@ -804,7 +991,13 @@ fn transpose_3d_inner_raw(t: &Tensor, batch: usize, rows: usize, cols: usize) ->
|
||||
}
|
||||
|
||||
/// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor (any dtype, raw bytes).
|
||||
fn slice_expert_range_3d_raw(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor {
|
||||
fn slice_expert_range_3d_raw(
|
||||
t: &Tensor,
|
||||
start: usize,
|
||||
count: usize,
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(t.ndim(), 3);
|
||||
let host = t.to_device(Device::Cpu);
|
||||
let elem_size = t.dtype().size_bytes();
|
||||
@@ -826,7 +1019,13 @@ fn slice_scale_range(t: &Tensor, start: usize, count: usize) -> Tensor {
|
||||
}
|
||||
|
||||
/// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor
|
||||
fn slice_expert_range_3d(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor {
|
||||
fn slice_expert_range_3d(
|
||||
t: &Tensor,
|
||||
start: usize,
|
||||
count: usize,
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(t.ndim(), 3);
|
||||
let host = t.to_device(Device::Cpu);
|
||||
let data = host.as_slice::<bf16>();
|
||||
|
||||
@@ -59,7 +59,9 @@ impl GptOssDecodeGraph {
|
||||
|
||||
model.decode_prepare(&[position], &[slot], cache);
|
||||
ids_buf.copy_from_host(&token.to_le_bytes()).unwrap();
|
||||
pos_buf.copy_from_host(&(position as u32).to_le_bytes()).unwrap();
|
||||
pos_buf
|
||||
.copy_from_host(&(position as u32).to_le_bytes())
|
||||
.unwrap();
|
||||
|
||||
// Retained warmup: run the exact step once eagerly with the quarantine
|
||||
// ON. Freed intermediates are held back instead of recycled, so the
|
||||
@@ -88,21 +90,32 @@ impl GptOssDecodeGraph {
|
||||
let logits;
|
||||
{
|
||||
let _guard = xserv_cuda::stream::push_stream(&stream);
|
||||
graph.begin_capture(&stream).expect("begin decode-graph capture");
|
||||
graph
|
||||
.begin_capture(&stream)
|
||||
.expect("begin decode-graph capture");
|
||||
logits = model.decode_core(
|
||||
ids_buf.as_ptr() as *const c_void,
|
||||
pos_buf.as_ptr() as *const c_void,
|
||||
1,
|
||||
cache,
|
||||
);
|
||||
graph.end_capture(&stream).expect("end decode-graph capture");
|
||||
graph
|
||||
.end_capture(&stream)
|
||||
.expect("end decode-graph capture");
|
||||
}
|
||||
let arena = allocator::end_retain();
|
||||
|
||||
graph.launch(&stream).expect("first decode-graph replay");
|
||||
cache.advance_seq_len(slot, 1);
|
||||
|
||||
Self { stream, graph, ids_buf, pos_buf, logits, _arena: arena }
|
||||
Self {
|
||||
stream,
|
||||
graph,
|
||||
ids_buf,
|
||||
pos_buf,
|
||||
logits,
|
||||
_arena: arena,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run one decode step by replaying the captured graph.
|
||||
@@ -116,8 +129,12 @@ impl GptOssDecodeGraph {
|
||||
) -> Tensor {
|
||||
model.decode_prepare(&[position], &[slot], cache);
|
||||
self.ids_buf.copy_from_host(&token.to_le_bytes()).unwrap();
|
||||
self.pos_buf.copy_from_host(&(position as u32).to_le_bytes()).unwrap();
|
||||
self.graph.launch(&self.stream).expect("decode-graph replay");
|
||||
self.pos_buf
|
||||
.copy_from_host(&(position as u32).to_le_bytes())
|
||||
.unwrap();
|
||||
self.graph
|
||||
.launch(&self.stream)
|
||||
.expect("decode-graph replay");
|
||||
cache.advance_seq_len(slot, 1);
|
||||
// Shallow clone: the caller reads these logits before the next replay
|
||||
// rewrites the underlying buffer.
|
||||
@@ -137,8 +154,14 @@ pub struct GraphedGptOssDecoder {
|
||||
|
||||
impl GraphedGptOssDecoder {
|
||||
pub fn new() -> Self {
|
||||
let enabled = std::env::var("XSERV_DECODE_GRAPH").map(|v| v != "0").unwrap_or(true);
|
||||
Self { graph: None, eager_steps: 0, enabled }
|
||||
let enabled = std::env::var("XSERV_DECODE_GRAPH")
|
||||
.map(|v| v != "0")
|
||||
.unwrap_or(true);
|
||||
Self {
|
||||
graph: None,
|
||||
eager_steps: 0,
|
||||
enabled,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::config::ModelConfig;
|
||||
use xserv_cuda::GpuBuffer;
|
||||
use xserv_tensor::{DType, Tensor};
|
||||
use crate::config::ModelConfig;
|
||||
|
||||
/// GPU-resident KV cache. Pre-allocates max_seq_len on GPU,
|
||||
/// appends new K/V via D2D copy at offset (no CPU round-trip).
|
||||
@@ -46,17 +46,43 @@ impl GpuKVCache {
|
||||
v_staging.push(GpuBuffer::alloc(buf_size).expect("alloc KV staging V"));
|
||||
}
|
||||
|
||||
Self { k_bufs, v_bufs, k_staging, v_staging, seq_len: 0, max_seq_len, num_kv_heads, head_dim, elem_size, dtype, device }
|
||||
Self {
|
||||
k_bufs,
|
||||
v_bufs,
|
||||
k_staging,
|
||||
v_staging,
|
||||
seq_len: 0,
|
||||
max_seq_len,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
elem_size,
|
||||
dtype,
|
||||
device,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn seq_len(&self) -> usize { self.seq_len }
|
||||
pub fn max_seq_len(&self) -> usize { self.max_seq_len }
|
||||
pub fn seq_len(&self) -> usize {
|
||||
self.seq_len
|
||||
}
|
||||
pub fn max_seq_len(&self) -> usize {
|
||||
self.max_seq_len
|
||||
}
|
||||
|
||||
/// Append new K/V tensors for a given layer.
|
||||
/// k_new, v_new: [1, num_kv_heads, new_tokens, head_dim] on GPU, contiguous.
|
||||
/// `write_pos` is the sequence position to write at (caller manages this).
|
||||
pub fn append(&mut self, layer: usize, k_new: &Tensor, v_new: &Tensor, new_tokens: usize, write_pos: usize) {
|
||||
assert!(write_pos + new_tokens <= self.max_seq_len, "KV cache overflow");
|
||||
pub fn append(
|
||||
&mut self,
|
||||
layer: usize,
|
||||
k_new: &Tensor,
|
||||
v_new: &Tensor,
|
||||
new_tokens: usize,
|
||||
write_pos: usize,
|
||||
) {
|
||||
assert!(
|
||||
write_pos + new_tokens <= self.max_seq_len,
|
||||
"KV cache overflow"
|
||||
);
|
||||
let es = self.elem_size;
|
||||
let hd = self.head_dim;
|
||||
let max_s = self.max_seq_len;
|
||||
@@ -69,14 +95,23 @@ impl GpuKVCache {
|
||||
let src_off = h * new_tokens * hd * es;
|
||||
let dst_off = (h * max_s + write_pos) * hd * es;
|
||||
let count = new_tokens * hd * es;
|
||||
self.k_bufs[layer].copy_from_device_at(k_src, src_off, dst_off, count).unwrap();
|
||||
self.v_bufs[layer].copy_from_device_at(v_src, src_off, dst_off, count).unwrap();
|
||||
self.k_bufs[layer]
|
||||
.copy_from_device_at(k_src, src_off, dst_off, count)
|
||||
.unwrap();
|
||||
self.v_bufs[layer]
|
||||
.copy_from_device_at(v_src, src_off, dst_off, count)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn advance_seq_len(&mut self, new_tokens: usize) {
|
||||
self.seq_len += new_tokens;
|
||||
assert!(self.seq_len <= self.max_seq_len, "KV cache seq_len ({}) exceeds max_seq_len ({})", self.seq_len, self.max_seq_len);
|
||||
assert!(
|
||||
self.seq_len <= self.max_seq_len,
|
||||
"KV cache seq_len ({}) exceeds max_seq_len ({})",
|
||||
self.seq_len,
|
||||
self.max_seq_len
|
||||
);
|
||||
}
|
||||
|
||||
/// Get K/V cache tensors for a layer up to `seq_len` tokens: [1, num_kv_heads, seq_len, head_dim]
|
||||
@@ -86,7 +121,11 @@ impl GpuKVCache {
|
||||
}
|
||||
|
||||
pub fn get_kv_len(&mut self, layer: usize, sl: usize) -> (Tensor, Tensor) {
|
||||
assert!(sl <= self.max_seq_len, "get_kv_len: sl ({sl}) exceeds max_seq_len ({})", self.max_seq_len);
|
||||
assert!(
|
||||
sl <= self.max_seq_len,
|
||||
"get_kv_len: sl ({sl}) exceeds max_seq_len ({})",
|
||||
self.max_seq_len
|
||||
);
|
||||
let hd = self.head_dim;
|
||||
let nh = self.num_kv_heads;
|
||||
let es = self.elem_size;
|
||||
@@ -104,8 +143,12 @@ impl GpuKVCache {
|
||||
let src_off = (h * max_s) * hd * es;
|
||||
let dst_off = (h * sl) * hd * es;
|
||||
let count = sl * hd * es;
|
||||
k_stg.copy_from_device_at(k_buf, src_off, dst_off, count).unwrap();
|
||||
v_stg.copy_from_device_at(v_buf, src_off, dst_off, count).unwrap();
|
||||
k_stg
|
||||
.copy_from_device_at(k_buf, src_off, dst_off, count)
|
||||
.unwrap();
|
||||
v_stg
|
||||
.copy_from_device_at(v_buf, src_off, dst_off, count)
|
||||
.unwrap();
|
||||
}
|
||||
// Grab raw pointers before dropping the mutable borrows
|
||||
let k_ptr = k_stg.as_mut_ptr();
|
||||
@@ -117,20 +160,35 @@ impl GpuKVCache {
|
||||
// get_kv_len call overwrites the staging buffer).
|
||||
let shape = &[1usize, nh, sl, hd];
|
||||
let k = unsafe {
|
||||
tensor_from_gpu_buffer(GpuBuffer::borrow_raw(k_ptr, out_size), shape, self.dtype, self.device)
|
||||
tensor_from_gpu_buffer(
|
||||
GpuBuffer::borrow_raw(k_ptr, out_size),
|
||||
shape,
|
||||
self.dtype,
|
||||
self.device,
|
||||
)
|
||||
};
|
||||
let v = unsafe {
|
||||
tensor_from_gpu_buffer(GpuBuffer::borrow_raw(v_ptr, out_size), shape, self.dtype, self.device)
|
||||
tensor_from_gpu_buffer(
|
||||
GpuBuffer::borrow_raw(v_ptr, out_size),
|
||||
shape,
|
||||
self.dtype,
|
||||
self.device,
|
||||
)
|
||||
};
|
||||
(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a Tensor from a GpuBuffer (takes ownership).
|
||||
unsafe fn tensor_from_gpu_buffer(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor {
|
||||
use xserv_tensor::storage::Storage;
|
||||
use xserv_tensor::shape::contiguous_strides;
|
||||
unsafe fn tensor_from_gpu_buffer(
|
||||
buf: GpuBuffer,
|
||||
shape: &[usize],
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
) -> Tensor {
|
||||
use smallvec::SmallVec;
|
||||
use xserv_tensor::shape::contiguous_strides;
|
||||
use xserv_tensor::storage::Storage;
|
||||
|
||||
let storage = Storage::cuda(buf, device);
|
||||
Tensor::from_storage(
|
||||
@@ -146,6 +204,11 @@ unsafe fn tensor_from_gpu_buffer(buf: GpuBuffer, shape: &[usize], dtype: DType,
|
||||
///
|
||||
/// # Safety
|
||||
/// `buf` must be a valid GPU allocation with at least `product(shape) * dtype.size_bytes()` bytes.
|
||||
pub unsafe fn tensor_from_gpu_buffer_pub(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor {
|
||||
pub unsafe fn tensor_from_gpu_buffer_pub(
|
||||
buf: GpuBuffer,
|
||||
shape: &[usize],
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
) -> Tensor {
|
||||
tensor_from_gpu_buffer(buf, shape, dtype, device)
|
||||
}
|
||||
|
||||
@@ -11,11 +11,11 @@ pub mod sampling;
|
||||
|
||||
pub use config::ModelConfig;
|
||||
pub use decode_graph::{DecodeGraphState, LayerWeightPtrs};
|
||||
pub use gpt2::{GPT2, KVCache};
|
||||
pub use gpt_oss::GptOss;
|
||||
pub use gpt_oss_graph::{GptOssDecodeGraph, GraphedGptOssDecoder};
|
||||
pub use gpt2::{GPT2, KVCache};
|
||||
pub use kv_cache::GpuKVCache;
|
||||
pub use paged_kv_cache::{BlockAllocator, Location, PagedKVCache, BLOCK_SIZE};
|
||||
pub use paged_kv_cache::{BLOCK_SIZE, BlockAllocator, Location, PagedKVCache};
|
||||
pub use qwen3::Qwen3;
|
||||
pub use sampling::{SamplingParams, sample, sample_greedy_penalized};
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ use std::path::Path;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
pub fn load_safetensors(path: &Path, device: Device) -> HashMap<String, Tensor> {
|
||||
let data = std::fs::read(path)
|
||||
.unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
|
||||
let data =
|
||||
std::fs::read(path).unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
|
||||
let st = SafeTensors::deserialize(&data)
|
||||
.unwrap_or_else(|e| panic!("failed to parse safetensors {}: {e}", path.display()));
|
||||
|
||||
@@ -60,7 +60,11 @@ pub fn load_model_dir(dir: &Path, device: Device) -> HashMap<String, Tensor> {
|
||||
all_tensors.extend(tensors);
|
||||
}
|
||||
|
||||
assert!(!all_tensors.is_empty(), "no safetensors files found in {}", dir.display());
|
||||
assert!(
|
||||
!all_tensors.is_empty(),
|
||||
"no safetensors files found in {}",
|
||||
dir.display()
|
||||
);
|
||||
all_tensors
|
||||
}
|
||||
|
||||
@@ -84,8 +88,6 @@ fn make_tensor(raw_bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor {
|
||||
};
|
||||
Tensor::from_slice(bfs, shape)
|
||||
}
|
||||
DType::FP8E4M3 => {
|
||||
Tensor::from_raw_bytes(raw_bytes, shape, DType::FP8E4M3)
|
||||
}
|
||||
DType::FP8E4M3 => Tensor::from_raw_bytes(raw_bytes, shape, DType::FP8E4M3),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,7 +29,10 @@ impl BlockAllocator {
|
||||
for b in (1..total_blocks).rev() {
|
||||
free_stack.push(b as u32);
|
||||
}
|
||||
Self { free_stack, total: total_blocks }
|
||||
Self {
|
||||
free_stack,
|
||||
total: total_blocks,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn alloc(&mut self) -> Option<u32> {
|
||||
@@ -136,8 +139,14 @@ impl PagedKVCache {
|
||||
device: u32,
|
||||
) -> Self {
|
||||
Self::new_tp(
|
||||
config, config.num_kv_heads(), total_blocks, cpu_total_blocks,
|
||||
max_seqs, max_blocks_per_seq, dtype, device,
|
||||
config,
|
||||
config.num_kv_heads(),
|
||||
total_blocks,
|
||||
cpu_total_blocks,
|
||||
max_seqs,
|
||||
max_blocks_per_seq,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -155,7 +164,10 @@ impl PagedKVCache {
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
) -> Self {
|
||||
assert!(total_blocks >= 2, "need at least 2 blocks (one is sentinel)");
|
||||
assert!(
|
||||
total_blocks >= 2,
|
||||
"need at least 2 blocks (one is sentinel)"
|
||||
);
|
||||
let num_layers = config.num_layers();
|
||||
let head_dim = config.head_dim();
|
||||
let elem_size = dtype.size_bytes();
|
||||
@@ -179,11 +191,17 @@ impl PagedKVCache {
|
||||
if cpu_total_blocks >= 2 {
|
||||
let cpu_pool_bytes = cpu_total_blocks * block_bytes;
|
||||
for _ in 0..num_layers {
|
||||
cpu_k_pools.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU K swap pool"));
|
||||
cpu_v_pools.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU V swap pool"));
|
||||
cpu_k_pools
|
||||
.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU K swap pool"));
|
||||
cpu_v_pools
|
||||
.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU V swap pool"));
|
||||
}
|
||||
}
|
||||
let cpu_allocator = BlockAllocator::new(if cpu_total_blocks >= 2 { cpu_total_blocks } else { 0 });
|
||||
let cpu_allocator = BlockAllocator::new(if cpu_total_blocks >= 2 {
|
||||
cpu_total_blocks
|
||||
} else {
|
||||
0
|
||||
});
|
||||
|
||||
let block_table_gpu =
|
||||
GpuBuffer::alloc(max_seqs * max_blocks_per_seq * std::mem::size_of::<i32>())
|
||||
@@ -220,22 +238,49 @@ impl PagedKVCache {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn num_layers(&self) -> usize { self.num_layers }
|
||||
pub fn num_kv_heads(&self) -> usize { self.num_kv_heads }
|
||||
pub fn head_dim(&self) -> usize { self.head_dim }
|
||||
pub fn dtype(&self) -> DType { self.dtype }
|
||||
pub fn max_seqs(&self) -> usize { self.max_seqs }
|
||||
pub fn max_blocks_per_seq(&self) -> usize { self.max_blocks_per_seq }
|
||||
pub fn free_blocks(&self) -> usize { self.allocator.free_count() }
|
||||
pub fn total_blocks(&self) -> usize { self.allocator.total() }
|
||||
pub fn num_layers(&self) -> usize {
|
||||
self.num_layers
|
||||
}
|
||||
pub fn num_kv_heads(&self) -> usize {
|
||||
self.num_kv_heads
|
||||
}
|
||||
pub fn head_dim(&self) -> usize {
|
||||
self.head_dim
|
||||
}
|
||||
pub fn dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
pub fn max_seqs(&self) -> usize {
|
||||
self.max_seqs
|
||||
}
|
||||
pub fn max_blocks_per_seq(&self) -> usize {
|
||||
self.max_blocks_per_seq
|
||||
}
|
||||
pub fn free_blocks(&self) -> usize {
|
||||
self.allocator.free_count()
|
||||
}
|
||||
pub fn total_blocks(&self) -> usize {
|
||||
self.allocator.total()
|
||||
}
|
||||
|
||||
pub fn k_pool(&self, layer: usize) -> &GpuBuffer { &self.k_pools[layer] }
|
||||
pub fn v_pool(&self, layer: usize) -> &GpuBuffer { &self.v_pools[layer] }
|
||||
pub fn block_table_gpu(&self) -> &GpuBuffer { &self.block_table_gpu }
|
||||
pub fn context_lens_gpu(&self) -> &GpuBuffer { &self.context_lens_gpu }
|
||||
pub fn k_pool(&self, layer: usize) -> &GpuBuffer {
|
||||
&self.k_pools[layer]
|
||||
}
|
||||
pub fn v_pool(&self, layer: usize) -> &GpuBuffer {
|
||||
&self.v_pools[layer]
|
||||
}
|
||||
pub fn block_table_gpu(&self) -> &GpuBuffer {
|
||||
&self.block_table_gpu
|
||||
}
|
||||
pub fn context_lens_gpu(&self) -> &GpuBuffer {
|
||||
&self.context_lens_gpu
|
||||
}
|
||||
|
||||
pub fn seq_len(&self, slot: usize) -> usize {
|
||||
self.seq_states[slot].as_ref().map(|s| s.seq_len).unwrap_or(0)
|
||||
self.seq_states[slot]
|
||||
.as_ref()
|
||||
.map(|s| s.seq_len)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn is_slot_free(&self, slot: usize) -> bool {
|
||||
@@ -280,7 +325,11 @@ impl PagedKVCache {
|
||||
let state = self.seq_states[slot].as_ref().expect("unregistered slot");
|
||||
let cur = state.block_ids.len();
|
||||
let needed_total = (state.seq_len + new_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
if needed_total > cur { needed_total - cur } else { 0 }
|
||||
if needed_total > cur {
|
||||
needed_total - cur
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
/// Pre-allocate enough physical blocks in `slot` to cover positions
|
||||
@@ -290,8 +339,14 @@ impl PagedKVCache {
|
||||
let state = self.seq_states[slot].as_mut().expect("unregistered slot");
|
||||
let needed_total = (end_pos + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
while state.block_ids.len() < needed_total {
|
||||
let b = self.allocator.alloc().expect("out of blocks (caller must check)");
|
||||
assert!(state.block_ids.len() < self.max_blocks_per_seq, "block table overflow");
|
||||
let b = self
|
||||
.allocator
|
||||
.alloc()
|
||||
.expect("out of blocks (caller must check)");
|
||||
assert!(
|
||||
state.block_ids.len() < self.max_blocks_per_seq,
|
||||
"block table overflow"
|
||||
);
|
||||
state.block_ids.push(b);
|
||||
}
|
||||
}
|
||||
@@ -318,7 +373,9 @@ impl PagedKVCache {
|
||||
num_tokens: usize,
|
||||
start_pos: usize,
|
||||
) {
|
||||
if num_tokens == 0 { return; }
|
||||
if num_tokens == 0 {
|
||||
return;
|
||||
}
|
||||
// Make sure blocks exist for the target range.
|
||||
self.ensure_capacity(slot, start_pos + num_tokens);
|
||||
|
||||
@@ -328,15 +385,21 @@ impl PagedKVCache {
|
||||
|
||||
// Stage block_ids on the GPU. Pool-allocated so this is essentially
|
||||
// free after the first call (same bucket every step).
|
||||
let block_ids: Vec<i32> = self.seq_states[slot].as_ref().unwrap()
|
||||
.block_ids.iter().map(|&b| b as i32).collect();
|
||||
let block_ids: Vec<i32> = self.seq_states[slot]
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.block_ids
|
||||
.iter()
|
||||
.map(|&b| b as i32)
|
||||
.collect();
|
||||
let bytes = block_ids.len() * std::mem::size_of::<i32>();
|
||||
let mut block_ids_gpu = xserv_cuda::allocator::cached_alloc(bytes)
|
||||
.expect("alloc append block_ids");
|
||||
let block_ids_bytes = unsafe {
|
||||
std::slice::from_raw_parts(block_ids.as_ptr() as *const u8, bytes)
|
||||
};
|
||||
block_ids_gpu.copy_from_host(block_ids_bytes).expect("upload block_ids");
|
||||
let mut block_ids_gpu =
|
||||
xserv_cuda::allocator::cached_alloc(bytes).expect("alloc append block_ids");
|
||||
let block_ids_bytes =
|
||||
unsafe { std::slice::from_raw_parts(block_ids.as_ptr() as *const u8, bytes) };
|
||||
block_ids_gpu
|
||||
.copy_from_host(block_ids_bytes)
|
||||
.expect("upload block_ids");
|
||||
|
||||
let k_src = k_new.data_ptr() as *const std::ffi::c_void;
|
||||
let v_src = v_new.data_ptr() as *const std::ffi::c_void;
|
||||
@@ -345,10 +408,16 @@ impl PagedKVCache {
|
||||
|
||||
unsafe {
|
||||
xserv_kernels::reshape_and_cache_bf16(
|
||||
k_src, v_src,
|
||||
k_pool_ptr, v_pool_ptr,
|
||||
k_src,
|
||||
v_src,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
block_ids_gpu.as_ptr() as *const i32,
|
||||
num_tokens, nkv, hd, start_pos, bs,
|
||||
num_tokens,
|
||||
nkv,
|
||||
hd,
|
||||
start_pos,
|
||||
bs,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
@@ -378,7 +447,9 @@ impl PagedKVCache {
|
||||
v_new: &Tensor,
|
||||
batch: usize,
|
||||
) {
|
||||
if batch == 0 { return; }
|
||||
if batch == 0 {
|
||||
return;
|
||||
}
|
||||
let nkv = self.num_kv_heads;
|
||||
let hd = self.head_dim;
|
||||
debug_assert_eq!(k_new.shape(), &[batch, nkv, hd]);
|
||||
@@ -393,10 +464,17 @@ impl PagedKVCache {
|
||||
|
||||
unsafe {
|
||||
xserv_kernels::reshape_and_cache_batched_bf16(
|
||||
k_src, v_src,
|
||||
k_pool_ptr, v_pool_ptr,
|
||||
bt_ptr, cl_ptr,
|
||||
batch, nkv, hd, BLOCK_SIZE, self.max_blocks_per_seq,
|
||||
k_src,
|
||||
v_src,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
bt_ptr,
|
||||
cl_ptr,
|
||||
batch,
|
||||
nkv,
|
||||
hd,
|
||||
BLOCK_SIZE,
|
||||
self.max_blocks_per_seq,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
@@ -447,7 +525,10 @@ impl PagedKVCache {
|
||||
/// before advance_seq_len has run).
|
||||
pub fn sync_active_batch_with_lens(&mut self, slots: &[usize], kv_lens: &[i32]) {
|
||||
assert_eq!(slots.len(), kv_lens.len());
|
||||
assert!(slots.len() <= self.max_seqs, "active batch exceeds max_seqs");
|
||||
assert!(
|
||||
slots.len() <= self.max_seqs,
|
||||
"active batch exceeds max_seqs"
|
||||
);
|
||||
let stride = self.max_blocks_per_seq;
|
||||
for row in &mut self.block_table_host {
|
||||
*row = 0;
|
||||
@@ -456,7 +537,9 @@ impl PagedKVCache {
|
||||
*cl = 0;
|
||||
}
|
||||
for (i, &slot) in slots.iter().enumerate() {
|
||||
let s = self.seq_states[slot].as_ref().expect("unregistered slot in active batch");
|
||||
let s = self.seq_states[slot]
|
||||
.as_ref()
|
||||
.expect("unregistered slot in active batch");
|
||||
let row = &mut self.block_table_host[i * stride..(i + 1) * stride];
|
||||
for (j, b) in s.block_ids.iter().enumerate() {
|
||||
row[j] = *b as i32;
|
||||
@@ -515,8 +598,12 @@ impl PagedKVCache {
|
||||
let src_off = ((phys * nkv + h) * bs + slot_in_blk) * hd * es;
|
||||
let dst_off = (h * sl + p) * hd * es;
|
||||
let count = chunk * hd * es;
|
||||
k_dst.copy_from_device_at(k_pool, src_off, dst_off, count).unwrap();
|
||||
v_dst.copy_from_device_at(v_pool, src_off, dst_off, count).unwrap();
|
||||
k_dst
|
||||
.copy_from_device_at(k_pool, src_off, dst_off, count)
|
||||
.unwrap();
|
||||
v_dst
|
||||
.copy_from_device_at(v_pool, src_off, dst_off, count)
|
||||
.unwrap();
|
||||
}
|
||||
p += chunk;
|
||||
}
|
||||
@@ -529,16 +616,26 @@ impl PagedKVCache {
|
||||
|
||||
// ----- Swapping (vLLM-style preemption to pinned host memory) -----
|
||||
|
||||
pub fn free_cpu_blocks(&self) -> usize { self.cpu_allocator.free_count() }
|
||||
pub fn swap_enabled(&self) -> bool { !self.cpu_k_pools.is_empty() }
|
||||
pub fn free_cpu_blocks(&self) -> usize {
|
||||
self.cpu_allocator.free_count()
|
||||
}
|
||||
pub fn swap_enabled(&self) -> bool {
|
||||
!self.cpu_k_pools.is_empty()
|
||||
}
|
||||
|
||||
pub fn is_swapped(&self, slot: usize) -> bool {
|
||||
matches!(self.seq_states[slot].as_ref().map(|s| s.location), Some(Location::Cpu))
|
||||
matches!(
|
||||
self.seq_states[slot].as_ref().map(|s| s.location),
|
||||
Some(Location::Cpu)
|
||||
)
|
||||
}
|
||||
|
||||
/// Number of physical blocks currently held by `slot` (in either pool).
|
||||
pub fn block_count(&self, slot: usize) -> usize {
|
||||
self.seq_states[slot].as_ref().map(|s| s.block_ids.len()).unwrap_or(0)
|
||||
self.seq_states[slot]
|
||||
.as_ref()
|
||||
.map(|s| s.block_ids.len())
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Whether a swapped sequence at `slot` can be brought back (enough free GPU blocks).
|
||||
@@ -554,11 +651,17 @@ impl PagedKVCache {
|
||||
/// Evict `slot`'s KV from GPU to pinned host memory and free its GPU blocks.
|
||||
/// The slot stays registered (location = Cpu); the sequence is paused.
|
||||
pub fn swap_out(&mut self, slot: usize) -> Result<(), &'static str> {
|
||||
let state = self.seq_states[slot].as_ref().ok_or("swap_out: empty slot")?;
|
||||
if state.location == Location::Cpu { return Ok(()); }
|
||||
let state = self.seq_states[slot]
|
||||
.as_ref()
|
||||
.ok_or("swap_out: empty slot")?;
|
||||
if state.location == Location::Cpu {
|
||||
return Ok(());
|
||||
}
|
||||
let gpu_ids = state.block_ids.clone();
|
||||
let n = gpu_ids.len();
|
||||
if !self.cpu_allocator.can_alloc(n) { return Err("swap_out: CPU pool full"); }
|
||||
if !self.cpu_allocator.can_alloc(n) {
|
||||
return Err("swap_out: CPU pool full");
|
||||
}
|
||||
|
||||
let cpu_ids: Vec<u32> = (0..n)
|
||||
.map(|_| self.cpu_allocator.alloc().expect("checked can_alloc"))
|
||||
@@ -570,10 +673,18 @@ impl PagedKVCache {
|
||||
let g_off = gpu_ids[i] as usize * bb;
|
||||
let c_off = cpu_ids[i] as usize * bb;
|
||||
self.k_pools[layer]
|
||||
.copy_to_host_at(&mut self.cpu_k_pools[layer].as_mut_slice()[c_off..c_off + bb], g_off, bb)
|
||||
.copy_to_host_at(
|
||||
&mut self.cpu_k_pools[layer].as_mut_slice()[c_off..c_off + bb],
|
||||
g_off,
|
||||
bb,
|
||||
)
|
||||
.unwrap();
|
||||
self.v_pools[layer]
|
||||
.copy_to_host_at(&mut self.cpu_v_pools[layer].as_mut_slice()[c_off..c_off + bb], g_off, bb)
|
||||
.copy_to_host_at(
|
||||
&mut self.cpu_v_pools[layer].as_mut_slice()[c_off..c_off + bb],
|
||||
g_off,
|
||||
bb,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -589,11 +700,17 @@ impl PagedKVCache {
|
||||
|
||||
/// Bring `slot`'s KV back from host to GPU and free its CPU blocks.
|
||||
pub fn swap_in(&mut self, slot: usize) -> Result<(), &'static str> {
|
||||
let state = self.seq_states[slot].as_ref().ok_or("swap_in: empty slot")?;
|
||||
if state.location == Location::Gpu { return Ok(()); }
|
||||
let state = self.seq_states[slot]
|
||||
.as_ref()
|
||||
.ok_or("swap_in: empty slot")?;
|
||||
if state.location == Location::Gpu {
|
||||
return Ok(());
|
||||
}
|
||||
let cpu_ids = state.block_ids.clone();
|
||||
let n = cpu_ids.len();
|
||||
if !self.allocator.can_alloc(n) { return Err("swap_in: GPU pool full"); }
|
||||
if !self.allocator.can_alloc(n) {
|
||||
return Err("swap_in: GPU pool full");
|
||||
}
|
||||
|
||||
let gpu_ids: Vec<u32> = (0..n)
|
||||
.map(|_| self.allocator.alloc().expect("checked can_alloc"))
|
||||
@@ -605,10 +722,18 @@ impl PagedKVCache {
|
||||
let g_off = gpu_ids[i] as usize * bb;
|
||||
let c_off = cpu_ids[i] as usize * bb;
|
||||
self.k_pools[layer]
|
||||
.copy_from_host_at(&self.cpu_k_pools[layer].as_slice()[c_off..c_off + bb], g_off, bb)
|
||||
.copy_from_host_at(
|
||||
&self.cpu_k_pools[layer].as_slice()[c_off..c_off + bb],
|
||||
g_off,
|
||||
bb,
|
||||
)
|
||||
.unwrap();
|
||||
self.v_pools[layer]
|
||||
.copy_from_host_at(&self.cpu_v_pools[layer].as_slice()[c_off..c_off + bb], g_off, bb)
|
||||
.copy_from_host_at(
|
||||
&self.cpu_v_pools[layer].as_slice()[c_off..c_off + bb],
|
||||
g_off,
|
||||
bb,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -623,7 +748,12 @@ impl PagedKVCache {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn tensor_from_owned_buf(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor {
|
||||
unsafe fn tensor_from_owned_buf(
|
||||
buf: GpuBuffer,
|
||||
shape: &[usize],
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
) -> Tensor {
|
||||
use smallvec::SmallVec;
|
||||
use xserv_tensor::shape::contiguous_strides;
|
||||
use xserv_tensor::storage::Storage;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use half::bf16;
|
||||
use std::collections::HashMap;
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{Device, Tensor};
|
||||
|
||||
@@ -41,9 +41,16 @@ struct Qwen3Block {
|
||||
}
|
||||
|
||||
impl Qwen3Block {
|
||||
fn q_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, 0, self.q_dim) }
|
||||
fn k_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, self.q_dim, self.kv_dim) }
|
||||
fn v_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, self.q_dim + self.kv_dim, self.kv_dim) }
|
||||
fn q_proj_wt(&self) -> Tensor {
|
||||
self.qkv_proj_wt.narrow(1, 0, self.q_dim)
|
||||
}
|
||||
fn k_proj_wt(&self) -> Tensor {
|
||||
self.qkv_proj_wt.narrow(1, self.q_dim, self.kv_dim)
|
||||
}
|
||||
fn v_proj_wt(&self) -> Tensor {
|
||||
self.qkv_proj_wt
|
||||
.narrow(1, self.q_dim + self.kv_dim, self.kv_dim)
|
||||
}
|
||||
fn gate_proj_wt(&self) -> Tensor {
|
||||
let half = self.gate_up_proj_wt.shape()[1] / 2;
|
||||
self.gate_up_proj_wt.narrow(1, 0, half)
|
||||
@@ -80,18 +87,31 @@ impl Qwen3 {
|
||||
crate::init_kernels();
|
||||
let dev = Device::Cuda(device);
|
||||
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
|
||||
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
w.remove(name)
|
||||
.unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
};
|
||||
// Replicated weight: upload whole to this rank's device.
|
||||
let repl = |t: Tensor| -> Tensor { t.to_device(dev) };
|
||||
// column-parallel: keep this rank's rows of [out, in], upload, transpose → [in, out/world].
|
||||
let col = |t: Tensor| -> Tensor { shard_rows(&t, rank, world).to_device(dev).transpose(0, 1).contiguous() };
|
||||
let col = |t: Tensor| -> Tensor {
|
||||
shard_rows(&t, rank, world)
|
||||
.to_device(dev)
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
};
|
||||
// row-parallel: keep this rank's cols of [out, in], upload, transpose → [in/world, out].
|
||||
let row = |t: Tensor| -> Tensor { shard_cols(&t, rank, world).to_device(dev).transpose(0, 1).contiguous() };
|
||||
let row = |t: Tensor| -> Tensor {
|
||||
shard_cols(&t, rank, world)
|
||||
.to_device(dev)
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
};
|
||||
|
||||
let embed_tokens = repl(take(&mut w, "model.embed_tokens.weight"));
|
||||
let norm = repl(take(&mut w, "model.norm.weight"));
|
||||
let lm_head_t = repl(take(&mut w, "lm_head.weight")).transpose(0, 1).contiguous();
|
||||
let lm_head_t = repl(take(&mut w, "lm_head.weight"))
|
||||
.transpose(0, 1)
|
||||
.contiguous();
|
||||
|
||||
let rope_cache = RopeCache::new(
|
||||
config.max_seq_len(),
|
||||
@@ -102,7 +122,10 @@ impl Qwen3 {
|
||||
let num_layers = config.num_layers();
|
||||
let mut layers = Vec::with_capacity(num_layers);
|
||||
if rank == 0 {
|
||||
eprintln!("Loading+sharding weights for {} layers (world={world})...", num_layers);
|
||||
eprintln!(
|
||||
"Loading+sharding weights for {} layers (world={world})...",
|
||||
num_layers
|
||||
);
|
||||
}
|
||||
for i in 0..num_layers {
|
||||
let p = format!("model.layers.{i}");
|
||||
@@ -126,7 +149,10 @@ impl Qwen3 {
|
||||
o_proj_wt: row(take(&mut w, &format!("{p}.self_attn.o_proj.weight"))),
|
||||
q_norm: repl(take(&mut w, &format!("{p}.self_attn.q_norm.weight"))),
|
||||
k_norm: repl(take(&mut w, &format!("{p}.self_attn.k_norm.weight"))),
|
||||
post_norm: repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight"))),
|
||||
post_norm: repl(take(
|
||||
&mut w,
|
||||
&format!("{p}.post_attention_layernorm.weight"),
|
||||
)),
|
||||
gate_up_proj_wt,
|
||||
down_proj_wt: row(take(&mut w, &format!("{p}.mlp.down_proj.weight"))),
|
||||
});
|
||||
@@ -165,7 +191,10 @@ impl Qwen3 {
|
||||
let dev = Device::Cuda(device);
|
||||
assert!(num_stages >= 1);
|
||||
let num_layers = config.num_layers();
|
||||
assert!(num_layers % num_stages == 0, "num_layers {num_layers} not divisible by pp {num_stages}");
|
||||
assert!(
|
||||
num_layers % num_stages == 0,
|
||||
"num_layers {num_layers} not divisible by pp {num_stages}"
|
||||
);
|
||||
let per_stage = num_layers / num_stages;
|
||||
let lo = stage * per_stage;
|
||||
let hi = lo + per_stage;
|
||||
@@ -173,16 +202,29 @@ impl Qwen3 {
|
||||
let is_last_stage = stage == num_stages - 1;
|
||||
|
||||
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
|
||||
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
w.remove(name)
|
||||
.unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
};
|
||||
let repl = |t: Tensor| -> Tensor { t.to_device(dev) };
|
||||
// Pre-transpose like the TP path's `col`/`row` do for world==1 (no shard).
|
||||
let wt = |t: Tensor| -> Tensor { t.to_device(dev).transpose(0, 1).contiguous() };
|
||||
let placeholder = || Tensor::from_slice(&[bf16::ZERO], &[1, 1]).to_device(dev);
|
||||
|
||||
let embed_tokens = if is_first_stage { repl(take(&mut w, "model.embed_tokens.weight")) } else { placeholder() };
|
||||
let norm = if is_last_stage { repl(take(&mut w, "model.norm.weight")) } else { placeholder() };
|
||||
let lm_head_t = if is_last_stage { wt(take(&mut w, "lm_head.weight")) } else { placeholder() };
|
||||
let embed_tokens = if is_first_stage {
|
||||
repl(take(&mut w, "model.embed_tokens.weight"))
|
||||
} else {
|
||||
placeholder()
|
||||
};
|
||||
let norm = if is_last_stage {
|
||||
repl(take(&mut w, "model.norm.weight"))
|
||||
} else {
|
||||
placeholder()
|
||||
};
|
||||
let lm_head_t = if is_last_stage {
|
||||
wt(take(&mut w, "lm_head.weight"))
|
||||
} else {
|
||||
placeholder()
|
||||
};
|
||||
|
||||
let rope_cache = RopeCache::new(
|
||||
config.max_seq_len(),
|
||||
@@ -217,7 +259,10 @@ impl Qwen3 {
|
||||
o_proj_wt: wt(take(&mut w, &format!("{p}.self_attn.o_proj.weight"))),
|
||||
q_norm: repl(take(&mut w, &format!("{p}.self_attn.q_norm.weight"))),
|
||||
k_norm: repl(take(&mut w, &format!("{p}.self_attn.k_norm.weight"))),
|
||||
post_norm: repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight"))),
|
||||
post_norm: repl(take(
|
||||
&mut w,
|
||||
&format!("{p}.post_attention_layernorm.weight"),
|
||||
)),
|
||||
gate_up_proj_wt,
|
||||
down_proj_wt: wt(take(&mut w, &format!("{p}.mlp.down_proj.weight"))),
|
||||
});
|
||||
@@ -252,8 +297,12 @@ impl Qwen3 {
|
||||
matmul_2d(&x, &self.lm_head_t)
|
||||
}
|
||||
|
||||
pub fn pp_is_first(&self) -> bool { self.is_first_stage }
|
||||
pub fn pp_is_last(&self) -> bool { self.is_last_stage }
|
||||
pub fn pp_is_first(&self) -> bool {
|
||||
self.is_first_stage
|
||||
}
|
||||
pub fn pp_is_last(&self) -> bool {
|
||||
self.is_last_stage
|
||||
}
|
||||
|
||||
/// PP prefill over THIS stage's layers. `x` is `[S, hidden]` (stage 0: from
|
||||
/// `embed`; otherwise received from the previous stage). Writes K/V for this
|
||||
@@ -276,7 +325,9 @@ impl Qwen3 {
|
||||
paged_cache.ensure_capacity(slot, pos_offset + new_tokens);
|
||||
paged_cache.advance_seq_len(slot, new_tokens);
|
||||
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
|
||||
.map(|p| p as u32)
|
||||
.collect();
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
@@ -285,7 +336,9 @@ impl Qwen3 {
|
||||
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
|
||||
let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
|
||||
let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
|
||||
let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
|
||||
let v = qkv
|
||||
.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
|
||||
.contiguous();
|
||||
|
||||
let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
|
||||
let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
|
||||
@@ -305,10 +358,12 @@ impl Qwen3 {
|
||||
let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
|
||||
let attn_out = flash_attention(&q, &k_full, &v_full, true);
|
||||
|
||||
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
|
||||
let attn_merged =
|
||||
xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
|
||||
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let (normed, x_new) =
|
||||
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
|
||||
@@ -356,7 +411,9 @@ impl Qwen3 {
|
||||
let qkv_all = matmul_2d(&normed, &layer.qkv_proj_wt);
|
||||
let q_all = qkv_all.narrow(1, 0, layer.q_dim).contiguous();
|
||||
let k_all = qkv_all.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
|
||||
let v_all = qkv_all.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
|
||||
let v_all = qkv_all
|
||||
.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
|
||||
.contiguous();
|
||||
|
||||
let mut q_rows: Vec<Tensor> = Vec::with_capacity(batch);
|
||||
for b in 0..batch {
|
||||
@@ -394,14 +451,23 @@ impl Qwen3 {
|
||||
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
|
||||
let attn_out = xserv_kernels::paged_decode_attention(
|
||||
&q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr,
|
||||
batch, num_heads, num_kv_heads, head_dim, max_blocks,
|
||||
&q_4d,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
bt_ptr,
|
||||
cl_ptr,
|
||||
batch,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
max_blocks,
|
||||
);
|
||||
|
||||
let attn_merged = attn_out.reshape(&[batch, num_heads * head_dim]);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
|
||||
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let (normed, x_new) =
|
||||
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
|
||||
@@ -441,7 +507,9 @@ impl Qwen3 {
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
|
||||
let mut x = embedding(&self.embed_tokens, token_ids);
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
|
||||
.map(|p| p as u32)
|
||||
.collect();
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
@@ -450,7 +518,9 @@ impl Qwen3 {
|
||||
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
|
||||
let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
|
||||
let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
|
||||
let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
|
||||
let v = qkv
|
||||
.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
|
||||
.contiguous();
|
||||
|
||||
let q = reshape_heads(&q, new_tokens, num_heads, head_dim);
|
||||
let k = reshape_heads(&k, new_tokens, num_kv_heads, head_dim);
|
||||
@@ -531,7 +601,9 @@ impl Qwen3 {
|
||||
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
|
||||
let q_all = qkv.narrow(1, 0, layer.q_dim).contiguous();
|
||||
let k_all = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
|
||||
let v_all = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
|
||||
let v_all = qkv
|
||||
.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
|
||||
.contiguous();
|
||||
|
||||
// Per-sequence: reshape, qk-norm, RoPE, KV cache, attention, merge
|
||||
let mut attn_outputs: Vec<Tensor> = Vec::with_capacity(batch);
|
||||
@@ -583,7 +655,8 @@ impl Qwen3 {
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
|
||||
// Fused add + rmsnorm
|
||||
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let (normed, x_new) =
|
||||
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
|
||||
@@ -668,7 +741,9 @@ impl Qwen3 {
|
||||
|
||||
// Per-head RMSNorm on contiguous copies (narrow views are strided).
|
||||
let q_flat = q_all.contiguous().reshape(&[batch * num_heads, head_dim]);
|
||||
let k_flat = k_all.contiguous().reshape(&[batch * num_kv_heads, head_dim]);
|
||||
let k_flat = k_all
|
||||
.contiguous()
|
||||
.reshape(&[batch * num_kv_heads, head_dim]);
|
||||
let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps);
|
||||
let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps);
|
||||
|
||||
@@ -688,8 +763,16 @@ impl Qwen3 {
|
||||
let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
let attn_out = xserv_kernels::paged_decode_attention(
|
||||
&q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr,
|
||||
batch, num_heads, num_kv_heads, head_dim, max_blocks,
|
||||
&q_4d,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
bt_ptr,
|
||||
cl_ptr,
|
||||
batch,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
max_blocks,
|
||||
);
|
||||
|
||||
// attn_out shape [B, H, 1, D] is contiguous-equivalent to [B, H*D].
|
||||
@@ -697,7 +780,8 @@ impl Qwen3 {
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
self.all_reduce(&attn_proj); // TP: sum partial attention outputs
|
||||
|
||||
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let (normed, x_new) =
|
||||
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
// Fused gate+up projection: one GEMV instead of two.
|
||||
@@ -743,7 +827,9 @@ impl Qwen3 {
|
||||
paged_cache.advance_seq_len(slot, new_tokens);
|
||||
|
||||
let mut x = embedding(&self.embed_tokens, token_ids);
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
|
||||
.map(|p| p as u32)
|
||||
.collect();
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
@@ -752,7 +838,9 @@ impl Qwen3 {
|
||||
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
|
||||
let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
|
||||
let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
|
||||
let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
|
||||
let v = qkv
|
||||
.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
|
||||
.contiguous();
|
||||
|
||||
let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
|
||||
let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
|
||||
@@ -773,11 +861,13 @@ impl Qwen3 {
|
||||
let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
|
||||
let attn_out = flash_attention(&q, &k_full, &v_full, true);
|
||||
|
||||
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
|
||||
let attn_merged =
|
||||
xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
self.all_reduce(&attn_proj);
|
||||
|
||||
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let (normed, x_new) =
|
||||
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
|
||||
@@ -805,7 +895,9 @@ impl Qwen3 {
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
|
||||
let mut x = embedding(&self.embed_tokens, token_ids);
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
|
||||
.map(|p| p as u32)
|
||||
.collect();
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
@@ -814,7 +906,9 @@ impl Qwen3 {
|
||||
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
|
||||
let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
|
||||
let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
|
||||
let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
|
||||
let v = qkv
|
||||
.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
|
||||
.contiguous();
|
||||
|
||||
let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
|
||||
let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
|
||||
@@ -834,10 +928,12 @@ impl Qwen3 {
|
||||
let (k_full, v_full) = cache.get_kv_len(layer_idx, pos_offset + new_tokens);
|
||||
|
||||
let attn_out = flash_attention(&q, &k_full, &v_full, true);
|
||||
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
|
||||
let attn_merged =
|
||||
xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
|
||||
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let (normed, x_new) =
|
||||
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
|
||||
@@ -856,7 +952,9 @@ impl Qwen3 {
|
||||
|
||||
/// Extract weight pointers for CUDA Graph capture.
|
||||
pub fn layer_weight_ptrs(&self) -> Vec<crate::decode_graph::LayerWeightPtrs> {
|
||||
self.layers.iter().map(|l| crate::decode_graph::LayerWeightPtrs {
|
||||
self.layers
|
||||
.iter()
|
||||
.map(|l| crate::decode_graph::LayerWeightPtrs {
|
||||
input_norm: l.input_norm.data_ptr() as *const std::ffi::c_void,
|
||||
q_proj_wt: l.q_proj_wt().data_ptr() as *const std::ffi::c_void,
|
||||
k_proj_wt: l.k_proj_wt().data_ptr() as *const std::ffi::c_void,
|
||||
@@ -868,11 +966,14 @@ impl Qwen3 {
|
||||
gate_proj_wt: l.gate_proj_wt().data_ptr() as *const std::ffi::c_void,
|
||||
up_proj_wt: l.up_proj_wt().data_ptr() as *const std::ffi::c_void,
|
||||
down_proj_wt: l.down_proj_wt.data_ptr() as *const std::ffi::c_void,
|
||||
}).collect()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get pointers needed for CUDA Graph capture.
|
||||
pub fn graph_capture_ptrs(&self) -> (
|
||||
pub fn graph_capture_ptrs(
|
||||
&self,
|
||||
) -> (
|
||||
*const std::ffi::c_void, // norm weight
|
||||
*const std::ffi::c_void, // lm_head_t
|
||||
*const std::ffi::c_void, // embed_tokens
|
||||
@@ -895,11 +996,16 @@ impl Qwen3 {
|
||||
/// (column-parallel split: split the OUTPUT dim). `world==1` returns the whole.
|
||||
/// Input must be a contiguous CPU (or device) BF16 tensor.
|
||||
fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor {
|
||||
if world == 1 { return t.clone(); }
|
||||
if world == 1 {
|
||||
return t.clone();
|
||||
}
|
||||
let shape = t.shape();
|
||||
assert_eq!(shape.len(), 2, "shard_rows expects 2D weight");
|
||||
let (rows, cols) = (shape[0], shape[1]);
|
||||
assert!(rows % world == 0, "rows {rows} not divisible by world {world}");
|
||||
assert!(
|
||||
rows % world == 0,
|
||||
"rows {rows} not divisible by world {world}"
|
||||
);
|
||||
let local = rows / world;
|
||||
let host = t.to_device(Device::Cpu);
|
||||
let data = host.as_slice::<bf16>();
|
||||
@@ -911,11 +1017,16 @@ fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor {
|
||||
/// Keep this rank's column-block of a 2D `[rows, cols]` BF16 tensor (row-parallel
|
||||
/// split: split the INPUT dim). Strided copy. `world==1` returns the whole.
|
||||
fn shard_cols(t: &Tensor, rank: usize, world: usize) -> Tensor {
|
||||
if world == 1 { return t.clone(); }
|
||||
if world == 1 {
|
||||
return t.clone();
|
||||
}
|
||||
let shape = t.shape();
|
||||
assert_eq!(shape.len(), 2, "shard_cols expects 2D weight");
|
||||
let (rows, cols) = (shape[0], shape[1]);
|
||||
assert!(cols % world == 0, "cols {cols} not divisible by world {world}");
|
||||
assert!(
|
||||
cols % world == 0,
|
||||
"cols {cols} not divisible by world {world}"
|
||||
);
|
||||
let local = cols / world;
|
||||
let c0 = rank * local;
|
||||
let host = t.to_device(Device::Cpu);
|
||||
@@ -1009,7 +1120,9 @@ fn transpose_from_rope(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: u
|
||||
}
|
||||
|
||||
fn repeat_kv(x: &Tensor, n_rep: usize) -> Tensor {
|
||||
if n_rep == 1 { return x.clone(); }
|
||||
if n_rep == 1 {
|
||||
return x.clone();
|
||||
}
|
||||
let kv_heads = x.shape()[1];
|
||||
let seq_len = x.shape()[2];
|
||||
let head_dim = x.shape()[3];
|
||||
@@ -1065,11 +1178,16 @@ fn concat_rows(rows: &[Tensor]) -> Tensor {
|
||||
let src_buf = row.storage().gpu_buffer();
|
||||
let src_offset = row.offset() * elem_size;
|
||||
let dst_offset = b * row_bytes;
|
||||
out_buf.copy_from_device_at(src_buf, src_offset, dst_offset, row_bytes).unwrap();
|
||||
out_buf
|
||||
.copy_from_device_at(src_buf, src_offset, dst_offset, row_bytes)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Wrap in a Tensor
|
||||
let device_id = match device { Device::Cuda(id) => id, _ => panic!("expected CUDA device") };
|
||||
let device_id = match device {
|
||||
Device::Cuda(id) => id,
|
||||
_ => panic!("expected CUDA device"),
|
||||
};
|
||||
unsafe {
|
||||
crate::kv_cache::tensor_from_gpu_buffer_pub(out_buf, &[batch, cols], dtype, device_id)
|
||||
}
|
||||
@@ -1082,12 +1200,15 @@ fn cat_cols(tensors: &[&Tensor]) -> Tensor {
|
||||
let dtype = tensors[0].dtype();
|
||||
let device = tensors[0].device();
|
||||
let elem = dtype.size_bytes();
|
||||
let total_cols: usize = tensors.iter().map(|t| {
|
||||
let total_cols: usize = tensors
|
||||
.iter()
|
||||
.map(|t| {
|
||||
assert_eq!(t.ndim(), 2);
|
||||
assert_eq!(t.shape()[0], rows);
|
||||
assert!(t.is_contiguous());
|
||||
t.shape()[1]
|
||||
}).sum();
|
||||
})
|
||||
.sum();
|
||||
let out = Tensor::empty(&[rows, total_cols], dtype, device);
|
||||
let dst_base = out.data_ptr() as *mut u8;
|
||||
for r in 0..rows {
|
||||
@@ -1126,7 +1247,9 @@ pub fn sample_greedy(logits: &Tensor) -> u32 {
|
||||
let seq_len = logits.shape()[0];
|
||||
let data = logits_cpu.as_slice::<bf16>();
|
||||
let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
|
||||
last.iter().enumerate()
|
||||
last.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(i, _)| i as u32).unwrap()
|
||||
.map(|(i, _)| i as u32)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
@@ -11,7 +11,11 @@ pub struct SamplingParams {
|
||||
|
||||
impl Default for SamplingParams {
|
||||
fn default() -> Self {
|
||||
Self { temperature: 0.0, top_k: 0, top_p: 1.0 }
|
||||
Self {
|
||||
temperature: 0.0,
|
||||
top_k: 0,
|
||||
top_p: 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,9 +138,14 @@ pub fn sample_greedy_penalized(logits: &Tensor, recent: &[u32], penalty: f32) ->
|
||||
let seq_len = logits.shape()[0];
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
let mut last_row: Vec<f32> = match logits.dtype() {
|
||||
DType::F32 => logits_cpu.as_slice::<f32>()[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec(),
|
||||
DType::BF16 => logits_cpu.as_slice::<bf16>()[(seq_len - 1) * vocab_size..seq_len * vocab_size]
|
||||
.iter().map(|v| v.to_f32()).collect(),
|
||||
DType::F32 => {
|
||||
logits_cpu.as_slice::<f32>()[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec()
|
||||
}
|
||||
DType::BF16 => logits_cpu.as_slice::<bf16>()
|
||||
[(seq_len - 1) * vocab_size..seq_len * vocab_size]
|
||||
.iter()
|
||||
.map(|v| v.to_f32())
|
||||
.collect(),
|
||||
_ => panic!("unsupported dtype for sampling: {:?}", logits.dtype()),
|
||||
};
|
||||
if penalty > 1.0 {
|
||||
|
||||
@@ -72,7 +72,10 @@ impl ChatTemplate {
|
||||
let source = std::fs::read_to_string(&jinja_path)
|
||||
.unwrap_or_else(|e| panic!("failed to read {}: {e}", jinja_path.display()));
|
||||
eprintln!("[chat-template] loaded from {}", jinja_path.display());
|
||||
return Self { source, model_type: model_type.to_string() };
|
||||
return Self {
|
||||
source,
|
||||
model_type: model_type.to_string(),
|
||||
};
|
||||
}
|
||||
|
||||
// 2. Try tokenizer_config.json → chat_template field
|
||||
@@ -82,7 +85,10 @@ impl ChatTemplate {
|
||||
if let Ok(v) = serde_json::from_str::<serde_json::Value>(&data) {
|
||||
if let Some(ct) = v.get("chat_template").and_then(|v| v.as_str()) {
|
||||
eprintln!("[chat-template] loaded from tokenizer_config.json");
|
||||
return Self { source: ct.to_string(), model_type: model_type.to_string() };
|
||||
return Self {
|
||||
source: ct.to_string(),
|
||||
model_type: model_type.to_string(),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -90,7 +96,10 @@ impl ChatTemplate {
|
||||
|
||||
// 3. No template found — use empty source, will fall back to hardcoded
|
||||
eprintln!("[chat-template] no Jinja template found, using hardcoded fallback");
|
||||
Self { source: String::new(), model_type: model_type.to_string() }
|
||||
Self {
|
||||
source: String::new(),
|
||||
model_type: model_type.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn render(&self, messages: &[Message]) -> String {
|
||||
@@ -206,7 +215,10 @@ fn build_prompt_gpt_oss(messages: &[Message]) -> String {
|
||||
prompt.push_str("<|start|>system<|message|>");
|
||||
prompt.push_str("You are ChatGPT, a large language model trained by OpenAI.\n");
|
||||
prompt.push_str("Knowledge cutoff: 2024-06\n");
|
||||
prompt.push_str(&format!("Current date: {}\n\n", strftime_now("%Y-%m-%d".to_string())));
|
||||
prompt.push_str(&format!(
|
||||
"Current date: {}\n\n",
|
||||
strftime_now("%Y-%m-%d".to_string())
|
||||
));
|
||||
prompt.push_str("Reasoning: low\n\n");
|
||||
prompt.push_str("# Valid channels: analysis, commentary, final. Channel must be included for every message.");
|
||||
prompt.push_str("<|end|>");
|
||||
@@ -334,13 +346,11 @@ async fn chat_non_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": prompt_token_count + completion_token_count
|
||||
}
|
||||
})).into_response()
|
||||
}))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn chat_stream(
|
||||
state: Arc<AppState>,
|
||||
req: ChatRequest,
|
||||
) -> Response {
|
||||
fn chat_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
|
||||
let id = format!("chatcmpl-{}", Uuid::new_v4());
|
||||
let model_name = state.model_name.clone();
|
||||
let created = unix_timestamp();
|
||||
@@ -356,7 +366,8 @@ fn chat_stream(
|
||||
if prompt_tokens.len() >= max_seq_len {
|
||||
return bad_request(format!(
|
||||
"prompt is {} tokens, exceeds max_seq_len {}",
|
||||
prompt_tokens.len(), max_seq_len
|
||||
prompt_tokens.len(),
|
||||
max_seq_len
|
||||
));
|
||||
}
|
||||
let max_tokens = req.max_tokens.min(max_seq_len - prompt_tokens.len());
|
||||
@@ -413,7 +424,9 @@ fn chat_stream(
|
||||
}
|
||||
});
|
||||
|
||||
Sse::new(ReceiverStream::new(sse_rx)).keep_alive(KeepAlive::default()).into_response()
|
||||
Sse::new(ReceiverStream::new(sse_rx))
|
||||
.keep_alive(KeepAlive::default())
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn validate_request(req: &ChatRequest, model_name: &str) -> Option<Response> {
|
||||
@@ -436,8 +449,13 @@ fn validate_request(req: &ChatRequest, model_name: &str) -> Option<Response> {
|
||||
/// prior handler panicked) and returns a clean 503 instead of panicking when the
|
||||
/// engine thread is gone, so one dead engine doesn't cascade into every request.
|
||||
fn submit_to_engine(state: &AppState, req: GenerateRequest) -> Result<(), Response> {
|
||||
let sender = state.engine_sender.lock().unwrap_or_else(|e| e.into_inner());
|
||||
sender.send(req).map_err(|_| service_unavailable("inference engine is not available"))
|
||||
let sender = state
|
||||
.engine_sender
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
sender
|
||||
.send(req)
|
||||
.map_err(|_| service_unavailable("inference engine is not available"))
|
||||
}
|
||||
|
||||
fn service_unavailable(message: impl Into<String>) -> Response {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::path::Path;
|
||||
use std::sync::mpsc;
|
||||
use std::sync::Once;
|
||||
use std::sync::mpsc;
|
||||
use std::time::Instant;
|
||||
use xserv_model::{ModelConfig, PagedKVCache, Qwen3, SamplingParams, sample, BLOCK_SIZE};
|
||||
use xserv_model::loader;
|
||||
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, SamplingParams, sample};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -109,12 +109,23 @@ impl Engine {
|
||||
(total_blocks * bytes_per_block) as f64 / 1e9,
|
||||
info.free_memory as f64 / 1e9,
|
||||
);
|
||||
Self { model, config, tokenizer, max_batch_size, max_seq_len, paged_cache }
|
||||
Self {
|
||||
model,
|
||||
config,
|
||||
tokenizer,
|
||||
max_batch_size,
|
||||
max_seq_len,
|
||||
paged_cache,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tokenizer(&self) -> &Tokenizer { &self.tokenizer }
|
||||
pub fn tokenizer(&self) -> &Tokenizer {
|
||||
&self.tokenizer
|
||||
}
|
||||
|
||||
pub fn max_seq_len(&self) -> usize { self.max_seq_len }
|
||||
pub fn max_seq_len(&self) -> usize {
|
||||
self.max_seq_len
|
||||
}
|
||||
|
||||
/// Main scheduler loop. Receives requests from channel, manages concurrent sequences.
|
||||
///
|
||||
@@ -134,7 +145,8 @@ impl Engine {
|
||||
|
||||
loop {
|
||||
// Step 1: Remove finished sequences and return their slots.
|
||||
let finished_slots: Vec<usize> = running.iter()
|
||||
let finished_slots: Vec<usize> = running
|
||||
.iter()
|
||||
.filter(|s| is_finished(s))
|
||||
.filter_map(|s| s.seq_slot)
|
||||
.collect();
|
||||
@@ -147,10 +159,16 @@ impl Engine {
|
||||
// room (oldest first). They resume decoding from where they paused.
|
||||
while running.len() < self.max_batch_size && !swapped.is_empty() {
|
||||
let slot = swapped[0].seq_slot.expect("swapped slot");
|
||||
if !self.paged_cache.can_swap_in(slot) { break; }
|
||||
if !self.paged_cache.can_swap_in(slot) {
|
||||
break;
|
||||
}
|
||||
self.paged_cache.swap_in(slot).expect("swap_in");
|
||||
let seq = swapped.remove(0);
|
||||
eprintln!("[scheduler] swapped in seq {} ({} blocks)", seq.id, self.paged_cache.block_count(slot));
|
||||
eprintln!(
|
||||
"[scheduler] swapped in seq {} ({} blocks)",
|
||||
seq.id,
|
||||
self.paged_cache.block_count(slot)
|
||||
);
|
||||
running.push(seq);
|
||||
}
|
||||
|
||||
@@ -161,14 +179,22 @@ impl Engine {
|
||||
let mut avail = self.paged_cache.free_blocks();
|
||||
let decode_reserve = running.len();
|
||||
while running.len() < self.max_batch_size {
|
||||
let Some(front) = waiting.front() else { break; };
|
||||
let Some(front) = waiting.front() else {
|
||||
break;
|
||||
};
|
||||
let prompt_blocks = front.prompt_tokens.len().div_ceil(BLOCK_SIZE).max(1);
|
||||
if avail < prompt_blocks + decode_reserve { break; }
|
||||
if avail < prompt_blocks + decode_reserve {
|
||||
break;
|
||||
}
|
||||
let free_slot = (0..self.paged_cache.max_seqs())
|
||||
.find(|&s| self.paged_cache.is_slot_free(s));
|
||||
let Some(slot) = free_slot else { break; };
|
||||
let Some(slot) = free_slot else {
|
||||
break;
|
||||
};
|
||||
let mut seq = waiting.pop_front().unwrap();
|
||||
self.paged_cache.register_sequence(slot).expect("register paged slot");
|
||||
self.paged_cache
|
||||
.register_sequence(slot)
|
||||
.expect("register paged slot");
|
||||
seq.seq_slot = Some(slot);
|
||||
running.push(seq);
|
||||
avail -= prompt_blocks; // projected free after this seq prefills
|
||||
@@ -199,7 +225,9 @@ impl Engine {
|
||||
if !seq.prefilled {
|
||||
let slot = seq.seq_slot.expect("slot");
|
||||
let logits = self.model.forward_prefill_paged(
|
||||
&seq.prompt_tokens, slot, &mut self.paged_cache,
|
||||
&seq.prompt_tokens,
|
||||
slot,
|
||||
&mut self.paged_cache,
|
||||
);
|
||||
let next = sample(&logits, &seq.sampling);
|
||||
seq.generated_tokens.push(next);
|
||||
@@ -219,13 +247,18 @@ impl Engine {
|
||||
&& !newly_prefilled.contains(&running[p].id)
|
||||
&& running[p].seq_slot.is_some()
|
||||
});
|
||||
let Some(pos) = victim else { break; };
|
||||
let Some(pos) = victim else {
|
||||
break;
|
||||
};
|
||||
let seq = running.remove(pos);
|
||||
let slot = seq.seq_slot.unwrap();
|
||||
if self.paged_cache.can_swap_out(slot) {
|
||||
let nblocks = self.paged_cache.block_count(slot);
|
||||
self.paged_cache.swap_out(slot).expect("swap_out");
|
||||
eprintln!("[scheduler] preempt: swapped out seq {} ({nblocks} blocks) to host", seq.id);
|
||||
eprintln!(
|
||||
"[scheduler] preempt: swapped out seq {} ({nblocks} blocks) to host",
|
||||
seq.id
|
||||
);
|
||||
swapped.push(seq);
|
||||
needed = decode_block_need(&self.paged_cache, &running, &newly_prefilled);
|
||||
} else {
|
||||
@@ -235,7 +268,9 @@ impl Engine {
|
||||
}
|
||||
|
||||
// Step 5c: Batched paged decode for the surviving prefilled sequences.
|
||||
let decode_indices: Vec<usize> = running.iter().enumerate()
|
||||
let decode_indices: Vec<usize> = running
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, s)| s.prefilled && !newly_prefilled.contains(&s.id))
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
@@ -246,25 +281,32 @@ impl Engine {
|
||||
eprintln!("[scheduler] paged decode active");
|
||||
});
|
||||
|
||||
let tokens: Vec<u32> = decode_indices.iter()
|
||||
let tokens: Vec<u32> = decode_indices
|
||||
.iter()
|
||||
.map(|&i| *running[i].generated_tokens.last().unwrap())
|
||||
.collect();
|
||||
let positions: Vec<usize> = decode_indices.iter()
|
||||
let positions: Vec<usize> = decode_indices
|
||||
.iter()
|
||||
.map(|&i| self.paged_cache.seq_len(running[i].seq_slot.unwrap()))
|
||||
.collect();
|
||||
let slots: Vec<usize> = decode_indices.iter()
|
||||
let slots: Vec<usize> = decode_indices
|
||||
.iter()
|
||||
.map(|&i| running[i].seq_slot.unwrap())
|
||||
.collect();
|
||||
|
||||
let logits = self.model.forward_decode_paged(
|
||||
&tokens, &positions, &slots, &mut self.paged_cache,
|
||||
&tokens,
|
||||
&positions,
|
||||
&slots,
|
||||
&mut self.paged_cache,
|
||||
);
|
||||
|
||||
// Fast path: every active sequence is greedy → run argmax on
|
||||
// the GPU and only D2H the chosen token ids (a few bytes per
|
||||
// sequence) instead of the full [B, vocab_size] BF16 logits
|
||||
// (~1.2 MB for B=4, Qwen3 vocab=152K).
|
||||
let all_greedy = decode_indices.iter()
|
||||
let all_greedy = decode_indices
|
||||
.iter()
|
||||
.all(|&i| running[i].sampling.temperature == 0.0);
|
||||
if all_greedy {
|
||||
let next_ids = xserv_kernels::argmax_bf16_to_host(&logits);
|
||||
@@ -285,11 +327,15 @@ impl Engine {
|
||||
let row_start = j * vocab_size;
|
||||
let row_logits = &data[row_start..row_start + vocab_size];
|
||||
let next = if running[i].sampling.temperature == 0.0 {
|
||||
row_logits.iter().enumerate()
|
||||
row_logits
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(idx, _)| idx as u32).unwrap()
|
||||
.map(|(idx, _)| idx as u32)
|
||||
.unwrap()
|
||||
} else {
|
||||
let row_tensor = xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]);
|
||||
let row_tensor =
|
||||
xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]);
|
||||
sample(&row_tensor, &running[i].sampling)
|
||||
};
|
||||
running[i].generated_tokens.push(next);
|
||||
@@ -334,7 +380,8 @@ impl Engine {
|
||||
/// Total additional GPU blocks the next decode step needs across all
|
||||
/// currently-decoding (prefilled, not just-prefilled) sequences.
|
||||
fn decode_block_need(paged: &PagedKVCache, running: &[Sequence], newly_prefilled: &[u64]) -> usize {
|
||||
running.iter()
|
||||
running
|
||||
.iter()
|
||||
.filter(|s| s.prefilled && !newly_prefilled.contains(&s.id))
|
||||
.filter_map(|s| s.seq_slot)
|
||||
.map(|slot| paged.additional_blocks_needed(slot, 1))
|
||||
@@ -372,8 +419,12 @@ fn send_token_if_nonempty(seq: &Sequence, text: String) {
|
||||
}
|
||||
|
||||
fn is_finished(seq: &Sequence) -> bool {
|
||||
if seq.generated_tokens.is_empty() { return false; }
|
||||
if seq.generated_tokens.is_empty() {
|
||||
return false;
|
||||
}
|
||||
let last = *seq.generated_tokens.last().unwrap();
|
||||
if seq.generated_tokens.len() >= seq.max_tokens { return true; }
|
||||
if seq.generated_tokens.len() >= seq.max_tokens {
|
||||
return true;
|
||||
}
|
||||
seq.sender.is_closed() || seq.eos_token_id == Some(last)
|
||||
}
|
||||
|
||||
@@ -3,10 +3,13 @@ mod engine;
|
||||
mod pp_engine;
|
||||
mod tp_engine;
|
||||
|
||||
use axum::{routing::{get, post}, Extension, Router};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{mpsc, Arc, Mutex};
|
||||
use axum::{
|
||||
Extension, Router,
|
||||
routing::{get, post},
|
||||
};
|
||||
use engine::GenerateRequest;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex, mpsc};
|
||||
use xserv_model::ModelConfig;
|
||||
|
||||
pub struct AppState {
|
||||
@@ -21,40 +24,48 @@ pub struct AppState {
|
||||
async fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!("Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N] [--tp N] [--pp N]");
|
||||
eprintln!(
|
||||
"Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N] [--tp N] [--pp N]"
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let model_dir = PathBuf::from(&args[1]);
|
||||
let port: u16 = args.iter()
|
||||
let port: u16 = args
|
||||
.iter()
|
||||
.position(|a| a == "--port")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(8080);
|
||||
let max_batch: usize = args.iter()
|
||||
let max_batch: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--max-batch")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(4)
|
||||
.max(1);
|
||||
let requested_max_seq_len: usize = args.iter()
|
||||
let requested_max_seq_len: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--max-seq-len")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(2048)
|
||||
.max(1);
|
||||
let swap_space_gb: usize = args.iter()
|
||||
let swap_space_gb: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--swap-space-gb")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(8);
|
||||
let tp: usize = args.iter()
|
||||
let tp: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--tp")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1)
|
||||
.max(1);
|
||||
let pp: usize = args.iter()
|
||||
let pp: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--pp")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
@@ -69,7 +80,9 @@ async fn main() {
|
||||
// tp=1 (single-rank world) so quantized models can serve on one GPU.
|
||||
let is_gpt_oss = model_config.model_type.as_deref() == Some("gpt_oss");
|
||||
if pp > 1 && is_gpt_oss {
|
||||
eprintln!("gpt-oss is not supported by the pipeline-parallel engine (Qwen3 only); use --tp instead");
|
||||
eprintln!(
|
||||
"gpt-oss is not supported by the pipeline-parallel engine (Qwen3 only); use --tp instead"
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
let model_max_seq_len = model_config.max_seq_len();
|
||||
@@ -84,7 +97,8 @@ async fn main() {
|
||||
);
|
||||
}
|
||||
|
||||
let model_name = model_dir.file_name()
|
||||
let model_name = model_dir
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
@@ -99,7 +113,12 @@ async fn main() {
|
||||
// Pipeline-parallel path: stage-0 coordinator + worker stage threads.
|
||||
pp_engine::run_pp(&model_dir_clone, pp, max_seq_len, rx);
|
||||
} else if tp <= 1 && !is_gpt_oss {
|
||||
let mut engine = engine::Engine::load_with_swap(&model_dir_clone, max_batch, max_seq_len, swap_space_gb);
|
||||
let mut engine = engine::Engine::load_with_swap(
|
||||
&model_dir_clone,
|
||||
max_batch,
|
||||
max_seq_len,
|
||||
swap_space_gb,
|
||||
);
|
||||
engine.run(rx);
|
||||
} else {
|
||||
// Tensor-parallel path: rank-0 coordinator + worker rank threads.
|
||||
|
||||
@@ -15,15 +15,15 @@
|
||||
|
||||
use std::ffi::c_void;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::mpsc;
|
||||
use std::sync::Arc;
|
||||
use std::sync::mpsc;
|
||||
use std::thread;
|
||||
|
||||
use half::bf16;
|
||||
use xserv_distributed::{PpContext, UniqueId};
|
||||
use xserv_model::loader;
|
||||
use xserv_model::sampling::SamplingParams;
|
||||
use xserv_model::{sample, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
|
||||
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, sample};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -38,9 +38,16 @@ enum PpCommand {
|
||||
Free(usize),
|
||||
/// Receive `[n_tokens, hidden]` from the previous stage, run this stage's
|
||||
/// layers; if last stage, sample with `sampling` and return the token.
|
||||
Prefill { n_tokens: usize, slot: usize, sampling: SamplingParams },
|
||||
Prefill {
|
||||
n_tokens: usize,
|
||||
slot: usize,
|
||||
sampling: SamplingParams,
|
||||
},
|
||||
/// Receive `[1, hidden]`, run this stage's layers; last stage samples.
|
||||
Decode { slot: usize, sampling: SamplingParams },
|
||||
Decode {
|
||||
slot: usize,
|
||||
sampling: SamplingParams,
|
||||
},
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
@@ -76,9 +83,21 @@ fn build_stage(
|
||||
let max_blocks_per_seq = max_seq_len.div_ceil(BLOCK_SIZE);
|
||||
let total_blocks = max_blocks_per_seq + 8; // v1 serial: one active sequence
|
||||
let cache = PagedKVCache::new(
|
||||
&stage_config, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device,
|
||||
&stage_config,
|
||||
total_blocks,
|
||||
0,
|
||||
4,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
device,
|
||||
);
|
||||
StageCtx { model, cache, pp, hidden: config.hidden(), device }
|
||||
StageCtx {
|
||||
model,
|
||||
cache,
|
||||
pp,
|
||||
hidden: config.hidden(),
|
||||
device,
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate a zeroed `[n, hidden]` device tensor and receive into it from `peer`.
|
||||
@@ -110,7 +129,15 @@ fn worker_loop(
|
||||
ack_tx: mpsc::Sender<()>,
|
||||
token_tx: mpsc::Sender<u32>,
|
||||
) {
|
||||
let mut sc = build_stage(&model_dir, &config, stage, world, stage as u32, max_seq_len, id);
|
||||
let mut sc = build_stage(
|
||||
&model_dir,
|
||||
&config,
|
||||
stage,
|
||||
world,
|
||||
stage as u32,
|
||||
max_seq_len,
|
||||
id,
|
||||
);
|
||||
let is_last = stage == world - 1;
|
||||
let prev = stage - 1;
|
||||
let next = stage + 1;
|
||||
@@ -125,7 +152,11 @@ fn worker_loop(
|
||||
sc.cache.free_sequence(slot);
|
||||
let _ = ack_tx.send(());
|
||||
}
|
||||
PpCommand::Prefill { n_tokens, slot, sampling } => {
|
||||
PpCommand::Prefill {
|
||||
n_tokens,
|
||||
slot,
|
||||
sampling,
|
||||
} => {
|
||||
let x = recv_hidden(&sc, n_tokens, prev);
|
||||
let x = sc.model.forward_layers_prefill(x, slot, &mut sc.cache);
|
||||
if is_last {
|
||||
@@ -155,7 +186,12 @@ fn worker_loop(
|
||||
|
||||
/// Run the PP coordinator (stage 0) on the calling thread. Spawns worker stages
|
||||
/// 1..world and consumes generation requests from `rx`.
|
||||
pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Receiver<GenerateRequest>) {
|
||||
pub fn run_pp(
|
||||
model_dir: &Path,
|
||||
world: usize,
|
||||
max_seq_len: usize,
|
||||
rx: mpsc::Receiver<GenerateRequest>,
|
||||
) {
|
||||
assert!(world >= 2, "run_pp requires world >= 2");
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
assert!(
|
||||
@@ -179,7 +215,17 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
|
||||
let model_dir = model_dir.to_path_buf();
|
||||
let config = config.clone();
|
||||
thread::spawn(move || {
|
||||
worker_loop(stage, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx, token_tx);
|
||||
worker_loop(
|
||||
stage,
|
||||
world,
|
||||
id,
|
||||
model_dir,
|
||||
config,
|
||||
max_seq_len,
|
||||
ctx_rx,
|
||||
ack_tx,
|
||||
token_tx,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -207,11 +253,14 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
|
||||
wait_acks(&ack_rx);
|
||||
|
||||
// Prefill: embed prompt, run stage-0 layers, push hidden into the pipe.
|
||||
broadcast(&cmd_txs, PpCommand::Prefill {
|
||||
broadcast(
|
||||
&cmd_txs,
|
||||
PpCommand::Prefill {
|
||||
n_tokens: req.prompt_tokens.len(),
|
||||
slot,
|
||||
sampling: req.sampling.clone(),
|
||||
});
|
||||
},
|
||||
);
|
||||
let x = sc.model.embed(&req.prompt_tokens);
|
||||
let x = sc.model.forward_layers_prefill(x, slot, &mut sc.cache);
|
||||
send_hidden(&sc, &x, next_peer);
|
||||
@@ -228,7 +277,13 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
|
||||
if generated >= req.max_tokens {
|
||||
break "length";
|
||||
}
|
||||
broadcast(&cmd_txs, PpCommand::Decode { slot, sampling: req.sampling.clone() });
|
||||
broadcast(
|
||||
&cmd_txs,
|
||||
PpCommand::Decode {
|
||||
slot,
|
||||
sampling: req.sampling.clone(),
|
||||
},
|
||||
);
|
||||
let x = sc.model.embed(&[next]);
|
||||
let x = sc.model.forward_layers_decode(x, &[slot], &mut sc.cache);
|
||||
send_hidden(&sc, &x, next_peer);
|
||||
@@ -239,9 +294,14 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
|
||||
|
||||
let tail = tokenizer.flush_decode_stream(&mut decode_buf);
|
||||
if !tail.is_empty() {
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Token { id: next, text: tail });
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Token {
|
||||
id: next,
|
||||
text: tail,
|
||||
});
|
||||
}
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Done { finish_reason: finish.to_string() });
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Done {
|
||||
finish_reason: finish.to_string(),
|
||||
});
|
||||
|
||||
broadcast(&cmd_txs, PpCommand::Free(slot));
|
||||
sc.cache.free_sequence(slot);
|
||||
@@ -258,6 +318,8 @@ fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, buf: &
|
||||
}
|
||||
let text = tokenizer.decode_token_stream(token_id, buf);
|
||||
if !text.is_empty() {
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Token { id: token_id, text });
|
||||
let _ = req
|
||||
.sender
|
||||
.blocking_send(GenerateEvent::Token { id: token_id, text });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,13 +13,16 @@
|
||||
//! work; the single-GPU `Engine` still handles TP=1.
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::mpsc;
|
||||
use std::sync::Arc;
|
||||
use std::sync::mpsc;
|
||||
use std::thread;
|
||||
|
||||
use xserv_distributed::{TpContext, UniqueId};
|
||||
use xserv_model::loader;
|
||||
use xserv_model::{sample, sample_greedy_penalized, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
|
||||
use xserv_model::{
|
||||
BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, sample,
|
||||
sample_greedy_penalized,
|
||||
};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -29,8 +32,15 @@ use crate::engine::{GenerateEvent, GenerateRequest};
|
||||
enum TpCommand {
|
||||
Register(usize),
|
||||
Free(usize),
|
||||
Prefill { tokens: Vec<u32>, slot: usize },
|
||||
Decode { tokens: Vec<u32>, positions: Vec<usize>, slots: Vec<usize> },
|
||||
Prefill {
|
||||
tokens: Vec<u32>,
|
||||
slot: usize,
|
||||
},
|
||||
Decode {
|
||||
tokens: Vec<u32>,
|
||||
positions: Vec<usize>,
|
||||
slots: Vec<usize>,
|
||||
},
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
@@ -40,14 +50,25 @@ enum TpModel {
|
||||
}
|
||||
|
||||
impl TpModel {
|
||||
fn forward_prefill_paged(&self, tokens: &[u32], slot: usize, cache: &mut PagedKVCache) -> Tensor {
|
||||
fn forward_prefill_paged(
|
||||
&self,
|
||||
tokens: &[u32],
|
||||
slot: usize,
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
match self {
|
||||
TpModel::Qwen3(m) => m.forward_prefill_paged(tokens, slot, cache),
|
||||
TpModel::GptOss(m) => m.forward_prefill_paged(tokens, slot, cache),
|
||||
}
|
||||
}
|
||||
|
||||
fn forward_decode_paged(&self, tokens: &[u32], positions: &[usize], slots: &[usize], cache: &mut PagedKVCache) -> Tensor {
|
||||
fn forward_decode_paged(
|
||||
&self,
|
||||
tokens: &[u32],
|
||||
positions: &[usize],
|
||||
slots: &[usize],
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
match self {
|
||||
TpModel::Qwen3(m) => m.forward_decode_paged(tokens, positions, slots, cache),
|
||||
TpModel::GptOss(m) => m.forward_decode_paged(tokens, positions, slots, cache),
|
||||
@@ -65,8 +86,12 @@ struct RankCtx {
|
||||
/// (lazy capture, replay thereafter); everything else runs eager.
|
||||
fn rank_decode(rc: &mut RankCtx, tokens: &[u32], positions: &[usize], slots: &[usize]) -> Tensor {
|
||||
match &rc.model {
|
||||
TpModel::GptOss(m) => rc.decoder.decode(m, tokens, positions, slots, &mut rc.cache),
|
||||
TpModel::Qwen3(_) => rc.model.forward_decode_paged(tokens, positions, slots, &mut rc.cache),
|
||||
TpModel::GptOss(m) => rc
|
||||
.decoder
|
||||
.decode(m, tokens, positions, slots, &mut rc.cache),
|
||||
TpModel::Qwen3(_) => rc
|
||||
.model
|
||||
.forward_decode_paged(tokens, positions, slots, &mut rc.cache),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,17 +106,42 @@ fn build_rank(
|
||||
) -> RankCtx {
|
||||
let weights = loader::load_model_dir(model_dir, Device::Cpu);
|
||||
let model = if config.is_moe() {
|
||||
TpModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, rank, world, device, tp))
|
||||
TpModel::GptOss(GptOss::from_weights_tp(
|
||||
config.clone(),
|
||||
weights,
|
||||
rank,
|
||||
world,
|
||||
device,
|
||||
tp,
|
||||
))
|
||||
} else {
|
||||
TpModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, rank, world, device, tp))
|
||||
TpModel::Qwen3(Qwen3::from_weights_tp(
|
||||
config.clone(),
|
||||
weights,
|
||||
rank,
|
||||
world,
|
||||
device,
|
||||
tp,
|
||||
))
|
||||
};
|
||||
let local_kv = config.num_kv_heads() / world;
|
||||
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
let total_blocks = max_blocks_per_seq + 8;
|
||||
let cache = PagedKVCache::new_tp(
|
||||
config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device,
|
||||
config,
|
||||
local_kv,
|
||||
total_blocks,
|
||||
0,
|
||||
4,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
device,
|
||||
);
|
||||
RankCtx { model, cache, decoder: GraphedGptOssDecoder::new() }
|
||||
RankCtx {
|
||||
model,
|
||||
cache,
|
||||
decoder: GraphedGptOssDecoder::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn worker_loop(
|
||||
@@ -105,7 +155,15 @@ fn worker_loop(
|
||||
ack_tx: mpsc::Sender<()>,
|
||||
) {
|
||||
let tp = Arc::new(TpContext::init(rank, world, id, rank as u32));
|
||||
let mut rc = build_rank(&model_dir, &config, rank, world, rank as u32, max_seq_len, Some(tp));
|
||||
let mut rc = build_rank(
|
||||
&model_dir,
|
||||
&config,
|
||||
rank,
|
||||
world,
|
||||
rank as u32,
|
||||
max_seq_len,
|
||||
Some(tp),
|
||||
);
|
||||
while let Ok(cmd) = cmd_rx.recv() {
|
||||
match cmd {
|
||||
TpCommand::Register(slot) => {
|
||||
@@ -115,7 +173,11 @@ fn worker_loop(
|
||||
TpCommand::Prefill { tokens, slot } => {
|
||||
let _ = rc.model.forward_prefill_paged(&tokens, slot, &mut rc.cache);
|
||||
}
|
||||
TpCommand::Decode { tokens, positions, slots } => {
|
||||
TpCommand::Decode {
|
||||
tokens,
|
||||
positions,
|
||||
slots,
|
||||
} => {
|
||||
let _ = rank_decode(&mut rc, &tokens, &positions, &slots);
|
||||
}
|
||||
TpCommand::Shutdown => {
|
||||
@@ -129,7 +191,12 @@ fn worker_loop(
|
||||
|
||||
/// Run the TP coordinator (rank 0) on the calling thread. Spawns worker ranks
|
||||
/// internally and consumes generation requests from `rx`.
|
||||
pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Receiver<GenerateRequest>) {
|
||||
pub fn run_tp(
|
||||
model_dir: &Path,
|
||||
world: usize,
|
||||
max_seq_len: usize,
|
||||
rx: mpsc::Receiver<GenerateRequest>,
|
||||
) {
|
||||
// world=1 is a valid single-rank configuration (gpt-oss has no
|
||||
// single-GPU engine path; NCCL init and all_reduce no-op at world=1).
|
||||
assert!(world >= 1, "run_tp requires world >= 1");
|
||||
@@ -152,7 +219,16 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
|
||||
let model_dir = model_dir.to_path_buf();
|
||||
let config = config.clone();
|
||||
thread::spawn(move || {
|
||||
worker_loop(rank, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx);
|
||||
worker_loop(
|
||||
rank,
|
||||
world,
|
||||
id,
|
||||
model_dir,
|
||||
config,
|
||||
max_seq_len,
|
||||
ctx_rx,
|
||||
ack_tx,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -165,10 +241,14 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
|
||||
// models loop under pure greedy when numerics diverge from the reference).
|
||||
// Off by default; XSERV_REP_PENALTY>1 enables it over the last
|
||||
// XSERV_REP_WINDOW generated tokens. Applied only on the greedy path.
|
||||
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY").ok()
|
||||
.and_then(|s| s.parse().ok()).unwrap_or(1.0);
|
||||
let rep_window: usize = std::env::var("XSERV_REP_WINDOW").ok()
|
||||
.and_then(|s| s.parse().ok()).unwrap_or(128);
|
||||
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1.0);
|
||||
let rep_window: usize = std::env::var("XSERV_REP_WINDOW")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(128);
|
||||
let pick = |logits: &Tensor, sp: &xserv_model::SamplingParams, history: &[u32]| -> u32 {
|
||||
if rep_penalty > 1.0 && sp.temperature == 0.0 {
|
||||
let start = history.len().saturating_sub(rep_window);
|
||||
@@ -197,8 +277,16 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
|
||||
wait_acks(&ack_rx);
|
||||
|
||||
// Prefill.
|
||||
broadcast(&cmd_txs, TpCommand::Prefill { tokens: req.prompt_tokens.clone(), slot });
|
||||
let logits = rc.model.forward_prefill_paged(&req.prompt_tokens, slot, &mut rc.cache);
|
||||
broadcast(
|
||||
&cmd_txs,
|
||||
TpCommand::Prefill {
|
||||
tokens: req.prompt_tokens.clone(),
|
||||
slot,
|
||||
},
|
||||
);
|
||||
let logits = rc
|
||||
.model
|
||||
.forward_prefill_paged(&req.prompt_tokens, slot, &mut rc.cache);
|
||||
wait_acks(&ack_rx);
|
||||
let mut gen_ids: Vec<u32> = Vec::new();
|
||||
let mut next = pick(&logits, &req.sampling, &gen_ids);
|
||||
@@ -216,7 +304,14 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
|
||||
break "length";
|
||||
}
|
||||
let pos = rc.cache.seq_len(slot);
|
||||
broadcast(&cmd_txs, TpCommand::Decode { tokens: vec![next], positions: vec![pos], slots: vec![slot] });
|
||||
broadcast(
|
||||
&cmd_txs,
|
||||
TpCommand::Decode {
|
||||
tokens: vec![next],
|
||||
positions: vec![pos],
|
||||
slots: vec![slot],
|
||||
},
|
||||
);
|
||||
let logits = rank_decode(&mut rc, &[next], &[pos], &[slot]);
|
||||
wait_acks(&ack_rx);
|
||||
next = pick(&logits, &req.sampling, &gen_ids);
|
||||
@@ -227,9 +322,14 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
|
||||
|
||||
let tail = tokenizer.flush_decode_stream(&mut decode_buf);
|
||||
if !tail.is_empty() {
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Token { id: next, text: tail });
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Token {
|
||||
id: next,
|
||||
text: tail,
|
||||
});
|
||||
}
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Done { finish_reason: finish.to_string() });
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Done {
|
||||
finish_reason: finish.to_string(),
|
||||
});
|
||||
|
||||
broadcast(&cmd_txs, TpCommand::Free(slot));
|
||||
rc.cache.free_sequence(slot);
|
||||
@@ -246,6 +346,8 @@ fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, buf: &
|
||||
}
|
||||
let text = tokenizer.decode_token_stream(token_id, buf);
|
||||
if !text.is_empty() {
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Token { id: token_id, text });
|
||||
let _ = req
|
||||
.sender
|
||||
.blocking_send(GenerateEvent::Token { id: token_id, text });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,18 +43,30 @@ pub trait TensorDType: Copy + Send + Sync + 'static {
|
||||
|
||||
impl TensorDType for f32 {
|
||||
const DTYPE: DType = DType::F32;
|
||||
fn to_f64(self) -> f64 { self as f64 }
|
||||
fn from_f64(v: f64) -> Self { v as f32 }
|
||||
fn to_f64(self) -> f64 {
|
||||
self as f64
|
||||
}
|
||||
fn from_f64(v: f64) -> Self {
|
||||
v as f32
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorDType for f16 {
|
||||
const DTYPE: DType = DType::F16;
|
||||
fn to_f64(self) -> f64 { self.to_f32() as f64 }
|
||||
fn from_f64(v: f64) -> Self { f16::from_f32(v as f32) }
|
||||
fn to_f64(self) -> f64 {
|
||||
self.to_f32() as f64
|
||||
}
|
||||
fn from_f64(v: f64) -> Self {
|
||||
f16::from_f32(v as f32)
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorDType for bf16 {
|
||||
const DTYPE: DType = DType::BF16;
|
||||
fn to_f64(self) -> f64 { self.to_f32() as f64 }
|
||||
fn from_f64(v: f64) -> Self { bf16::from_f32(v as f32) }
|
||||
fn to_f64(self) -> f64 {
|
||||
self.to_f32() as f64
|
||||
}
|
||||
fn from_f64(v: f64) -> Self {
|
||||
bf16::from_f32(v as f32)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,4 +6,4 @@ pub mod tensor;
|
||||
pub use dtype::{DType, TensorDType};
|
||||
pub use shape::Dims;
|
||||
pub use storage::{Device, Storage};
|
||||
pub use tensor::{register_gpu_contiguous, Tensor};
|
||||
pub use tensor::{Tensor, register_gpu_contiguous};
|
||||
|
||||
@@ -46,8 +46,16 @@ pub fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Dims> {
|
||||
let ndim = a.len().max(b.len());
|
||||
let mut result = SmallVec::with_capacity(ndim);
|
||||
for i in 0..ndim {
|
||||
let da = if i < ndim - a.len() { 1 } else { a[i - (ndim - a.len())] };
|
||||
let db = if i < ndim - b.len() { 1 } else { b[i - (ndim - b.len())] };
|
||||
let da = if i < ndim - a.len() {
|
||||
1
|
||||
} else {
|
||||
a[i - (ndim - a.len())]
|
||||
};
|
||||
let db = if i < ndim - b.len() {
|
||||
1
|
||||
} else {
|
||||
b[i - (ndim - b.len())]
|
||||
};
|
||||
if da == db {
|
||||
result.push(da);
|
||||
} else if da == 1 {
|
||||
@@ -100,8 +108,14 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_broadcast_shape() {
|
||||
assert_eq!(broadcast_shape(&[3, 1], &[1, 4]).unwrap().as_slice(), &[3, 4]);
|
||||
assert_eq!(broadcast_shape(&[2, 3, 4], &[4]).unwrap().as_slice(), &[2, 3, 4]);
|
||||
assert_eq!(
|
||||
broadcast_shape(&[3, 1], &[1, 4]).unwrap().as_slice(),
|
||||
&[3, 4]
|
||||
);
|
||||
assert_eq!(
|
||||
broadcast_shape(&[2, 3, 4], &[4]).unwrap().as_slice(),
|
||||
&[2, 3, 4]
|
||||
);
|
||||
assert_eq!(broadcast_shape(&[1], &[5, 3]).unwrap().as_slice(), &[5, 3]);
|
||||
assert!(broadcast_shape(&[3], &[4]).is_none());
|
||||
}
|
||||
@@ -109,6 +123,9 @@ mod tests {
|
||||
#[test]
|
||||
fn test_broadcast_strides() {
|
||||
// [3,1] with strides [1,1] broadcast to [3,4]
|
||||
assert_eq!(broadcast_strides(&[3, 1], &[1, 1], &[3, 4]).as_slice(), &[1, 0]);
|
||||
assert_eq!(
|
||||
broadcast_strides(&[3, 1], &[1, 1], &[3, 4]).as_slice(),
|
||||
&[1, 0]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,8 +33,20 @@ impl Tensor {
|
||||
// --- Creation ---
|
||||
|
||||
/// Create a tensor from raw components (for advanced use like GPU KV cache).
|
||||
pub fn from_storage(storage: Storage, shape: Dims, strides: Dims, offset: usize, dtype: DType) -> Self {
|
||||
Self { storage, shape, strides, offset, dtype }
|
||||
pub fn from_storage(
|
||||
storage: Storage,
|
||||
shape: Dims,
|
||||
strides: Dims,
|
||||
offset: usize,
|
||||
dtype: DType,
|
||||
) -> Self {
|
||||
Self {
|
||||
storage,
|
||||
shape,
|
||||
strides,
|
||||
offset,
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_slice<T: TensorDType>(data: &[T], shape: &[usize]) -> Self {
|
||||
@@ -60,7 +72,10 @@ impl Tensor {
|
||||
data.len(),
|
||||
numel * dtype.size_bytes(),
|
||||
"raw bytes length {} != expected {} (numel={} * elem_size={})",
|
||||
data.len(), numel * dtype.size_bytes(), numel, dtype.size_bytes()
|
||||
data.len(),
|
||||
numel * dtype.size_bytes(),
|
||||
numel,
|
||||
dtype.size_bytes()
|
||||
);
|
||||
Self {
|
||||
storage: Storage::cpu(data.to_vec()),
|
||||
@@ -112,14 +127,28 @@ impl Tensor {
|
||||
|
||||
// --- Properties ---
|
||||
|
||||
pub fn shape(&self) -> &[usize] { &self.shape }
|
||||
pub fn strides(&self) -> &[usize] { &self.strides }
|
||||
pub fn dtype(&self) -> DType { self.dtype }
|
||||
pub fn ndim(&self) -> usize { self.shape.len() }
|
||||
pub fn numel(&self) -> usize { shape::num_elements(&self.shape) }
|
||||
pub fn offset(&self) -> usize { self.offset }
|
||||
pub fn shape(&self) -> &[usize] {
|
||||
&self.shape
|
||||
}
|
||||
pub fn strides(&self) -> &[usize] {
|
||||
&self.strides
|
||||
}
|
||||
pub fn dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
pub fn ndim(&self) -> usize {
|
||||
self.shape.len()
|
||||
}
|
||||
pub fn numel(&self) -> usize {
|
||||
shape::num_elements(&self.shape)
|
||||
}
|
||||
pub fn offset(&self) -> usize {
|
||||
self.offset
|
||||
}
|
||||
|
||||
pub fn device(&self) -> Device { self.storage.device() }
|
||||
pub fn device(&self) -> Device {
|
||||
self.storage.device()
|
||||
}
|
||||
|
||||
pub fn is_contiguous(&self) -> bool {
|
||||
shape::is_contiguous(&self.shape, &self.strides)
|
||||
@@ -193,7 +222,11 @@ impl Tensor {
|
||||
shape::contiguous_strides(&new_shape)
|
||||
} else {
|
||||
let mut s = self.strides.clone();
|
||||
let stride_val = if dim < self.strides.len() { self.strides[dim] } else { 1 };
|
||||
let stride_val = if dim < self.strides.len() {
|
||||
self.strides[dim]
|
||||
} else {
|
||||
1
|
||||
};
|
||||
s.insert(dim, stride_val);
|
||||
s
|
||||
};
|
||||
@@ -230,7 +263,12 @@ impl Tensor {
|
||||
let ndim = self.ndim();
|
||||
let mut idx = vec![0usize; ndim];
|
||||
for flat in 0..numel {
|
||||
let src_offset = self.offset + idx.iter().zip(self.strides.iter()).map(|(i, s)| i * s).sum::<usize>();
|
||||
let src_offset = self.offset
|
||||
+ idx
|
||||
.iter()
|
||||
.zip(self.strides.iter())
|
||||
.map(|(i, s)| i * s)
|
||||
.sum::<usize>();
|
||||
let src_byte_offset = src_offset * elem_size;
|
||||
let dst_byte_offset = flat * elem_size;
|
||||
dst[dst_byte_offset..dst_byte_offset + elem_size]
|
||||
@@ -261,7 +299,10 @@ impl Tensor {
|
||||
}
|
||||
// Transfer the raw storage (preserving strides/offset).
|
||||
// Non-contiguous layout is preserved — the user can call contiguous() after.
|
||||
let new_storage = self.storage.to_device(device).expect("device transfer failed");
|
||||
let new_storage = self
|
||||
.storage
|
||||
.to_device(device)
|
||||
.expect("device transfer failed");
|
||||
Self {
|
||||
storage: new_storage,
|
||||
shape: self.shape.clone(),
|
||||
@@ -310,14 +351,20 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn storage(&self) -> &Storage { &self.storage }
|
||||
pub fn storage(&self) -> &Storage {
|
||||
&self.storage
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Tensor {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f, "Tensor(shape={:?}, dtype={}, device={}, contiguous={})",
|
||||
self.shape.as_slice(), self.dtype, self.device(), self.is_contiguous()
|
||||
f,
|
||||
"Tensor(shape={:?}, dtype={}, device={}, contiguous={})",
|
||||
self.shape.as_slice(),
|
||||
self.dtype,
|
||||
self.device(),
|
||||
self.is_contiguous()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,11 @@ fn test_zeros_and_ones() {
|
||||
|
||||
#[test]
|
||||
fn test_bf16_tensor() {
|
||||
let data: Vec<bf16> = vec![bf16::from_f32(1.0), bf16::from_f32(2.5), bf16::from_f32(-3.0)];
|
||||
let data: Vec<bf16> = vec![
|
||||
bf16::from_f32(1.0),
|
||||
bf16::from_f32(2.5),
|
||||
bf16::from_f32(-3.0),
|
||||
];
|
||||
let t = Tensor::from_slice(&data, &[3]);
|
||||
assert_eq!(t.dtype(), DType::BF16);
|
||||
let out = t.as_slice::<bf16>();
|
||||
|
||||
@@ -95,11 +95,15 @@ impl Tokenizer {
|
||||
let (a_str, b_str) = match entry {
|
||||
MergeEntry::Str(s) => {
|
||||
let parts: Vec<&str> = s.splitn(2, ' ').collect();
|
||||
if parts.len() != 2 { continue; }
|
||||
if parts.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
(parts[0].to_string(), parts[1].to_string())
|
||||
}
|
||||
MergeEntry::Pair(v) => {
|
||||
if v.len() != 2 { continue; }
|
||||
if v.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
(v[0].clone(), v[1].clone())
|
||||
}
|
||||
};
|
||||
@@ -174,7 +178,10 @@ impl Tokenizer {
|
||||
if byte_fallback {
|
||||
Regex::new(r"[\p{L}\p{N}]+|[^\s\p{L}\p{N}]|\s+").unwrap()
|
||||
} else {
|
||||
Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+").unwrap()
|
||||
Regex::new(
|
||||
r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+",
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -262,7 +269,9 @@ impl Tokenizer {
|
||||
|
||||
// BPE merges
|
||||
loop {
|
||||
if token_ids.len() < 2 { break; }
|
||||
if token_ids.len() < 2 {
|
||||
break;
|
||||
}
|
||||
let mut best_rank = usize::MAX;
|
||||
let mut best_idx = 0;
|
||||
for i in 0..token_ids.len() - 1 {
|
||||
@@ -273,12 +282,15 @@ impl Tokenizer {
|
||||
}
|
||||
}
|
||||
}
|
||||
if best_rank == usize::MAX { break; }
|
||||
if best_rank == usize::MAX {
|
||||
break;
|
||||
}
|
||||
|
||||
let merged_bytes = [
|
||||
self.decoder[token_ids[best_idx] as usize].as_slice(),
|
||||
self.decoder[token_ids[best_idx + 1] as usize].as_slice(),
|
||||
].concat();
|
||||
]
|
||||
.concat();
|
||||
let merged_id = *self.encoder.get(&merged_bytes).unwrap_or_else(|| {
|
||||
panic!("merged token not in vocab");
|
||||
});
|
||||
@@ -389,14 +401,13 @@ fn unicode_to_byte(c: char) -> u8 {
|
||||
m
|
||||
});
|
||||
|
||||
*map.get(&(c as u32)).unwrap_or_else(|| {
|
||||
panic!("unmapped unicode char U+{:04X} in tokenizer", c as u32)
|
||||
})
|
||||
*map.get(&(c as u32))
|
||||
.unwrap_or_else(|| panic!("unmapped unicode char U+{:04X} in tokenizer", c as u32))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{take_valid_utf8, Tokenizer};
|
||||
use super::{Tokenizer, take_valid_utf8};
|
||||
|
||||
#[test]
|
||||
fn qwen_added_tokens_are_indivisible_and_im_end_is_eos() {
|
||||
|
||||
Reference in New Issue
Block a user