phase 0+1: fix Rust 2024 edition compat + memory query

- unsafe extern "C" blocks (Rust 2024 requirement)
- unsafe blocks inside unsafe fn bodies
- Use cudaMemGetInfo for accurate GPU memory reporting
- Remove cc "cuda" feature (doesn't exist, built-in)
- All 12 tests pass on RTX 5090 (CC 12.0, 170 SMs, 32GB)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 19:40:49 +08:00
parent 9806b4db35
commit c8f7bc0c3c
5 changed files with 40 additions and 24 deletions

View File

@@ -4,7 +4,7 @@ version.workspace = true
edition.workspace = true
[build-dependencies]
cc = { version = "1", features = ["cuda"] }
cc = "1"
[dev-dependencies]
rand = "0.9"

View File

@@ -7,6 +7,7 @@ pub struct DeviceInfo {
pub index: u32,
pub name: String,
pub total_memory: usize,
pub free_memory: usize,
pub compute_major: i32,
pub compute_minor: i32,
pub sm_count: i32,
@@ -15,8 +16,9 @@ pub struct DeviceInfo {
pub max_threads_per_block: i32,
}
extern "C" {
unsafe extern "C" {
fn cudaDeviceGetAttribute(value: *mut i32, attr: i32, device: i32) -> i32;
fn cudaMemGetInfo(free: *mut usize, total: *mut usize) -> i32;
}
fn get_attr(attr: i32, device: u32) -> Result<i32> {
@@ -42,16 +44,24 @@ pub fn current_device() -> Result<u32> {
}
pub fn device_info(device: u32) -> Result<DeviceInfo> {
// Use cudaGetDeviceProperties only for the name (first field, always stable).
// Get device name from cudaGetDeviceProperties (only use the name field).
let mut prop = unsafe { std::mem::zeroed::<ffi::CudaDeviceProp>() };
error::check(unsafe { ffi::cudaGetDeviceProperties(&mut prop, device as i32) })?;
let name = unsafe { CStr::from_ptr(prop.name.as_ptr()) }
.to_string_lossy()
.into_owned();
// Use cudaDeviceGetAttribute for everything else (layout-independent).
// Attribute IDs from cuda_runtime_api.h:
const TOTAL_GLOBAL_MEM: i32 = 0; // not available via attribute, use prop
// Get memory info via cudaMemGetInfo (layout-independent).
let prev = current_device()?;
set_device(device)?;
let mut free = 0usize;
let mut total = 0usize;
error::check(unsafe { cudaMemGetInfo(&mut free, &mut total) })?;
if prev != device {
set_device(prev)?;
}
// Attribute IDs from cuda_runtime_api.h
const SHARED_MEM_PER_BLOCK: i32 = 8;
const WARP_SIZE: i32 = 10;
const MAX_THREADS_PER_BLOCK: i32 = 1;
@@ -62,7 +72,8 @@ pub fn device_info(device: u32) -> Result<DeviceInfo> {
Ok(DeviceInfo {
index: device,
name,
total_memory: prop.total_global_mem,
total_memory: total,
free_memory: free,
compute_major: get_attr(COMPUTE_MAJOR, device)?,
compute_minor: get_attr(COMPUTE_MINOR, device)?,
sm_count: get_attr(MULTI_PROCESSOR_COUNT, device)?,

View File

@@ -30,7 +30,7 @@ pub struct CudaDeviceProp {
_pad: [u8; 4096],
}
extern "C" {
unsafe extern "C" {
// --- Device ---
pub fn cudaGetDeviceCount(count: *mut i32) -> i32;
pub fn cudaSetDevice(device: i32) -> i32;

View File

@@ -48,26 +48,30 @@ impl GpuBuffer {
/// Safety: `src` must remain valid until the stream operation completes.
pub unsafe fn copy_from_host_async(&mut self, src: &[u8], stream: &CudaStream) -> Result<()> {
assert!(src.len() <= self.len);
error::check(ffi::cudaMemcpyAsync(
self.ptr,
src.as_ptr(),
src.len(),
ffi::CUDA_MEMCPY_H2D,
stream.as_raw(),
))
unsafe {
error::check(ffi::cudaMemcpyAsync(
self.ptr,
src.as_ptr(),
src.len(),
ffi::CUDA_MEMCPY_H2D,
stream.as_raw(),
))
}
}
/// Async copy from device to host on the given stream.
/// Safety: `dst` must remain valid until the stream operation completes.
pub unsafe fn copy_to_host_async(&self, dst: &mut [u8], stream: &CudaStream) -> Result<()> {
assert!(dst.len() <= self.len);
error::check(ffi::cudaMemcpyAsync(
dst.as_mut_ptr(),
self.ptr,
dst.len(),
ffi::CUDA_MEMCPY_D2H,
stream.as_raw(),
))
unsafe {
error::check(ffi::cudaMemcpyAsync(
dst.as_mut_ptr(),
self.ptr,
dst.len(),
ffi::CUDA_MEMCPY_D2H,
stream.as_raw(),
))
}
}
/// Copy from another GPU buffer (D2D).

View File

@@ -7,7 +7,8 @@ fn test_device_info() {
let info = device::device_info(0).expect("failed to get device info");
println!("GPU 0: {}", info.name);
println!(" Memory: {} MB", info.total_memory / (1024 * 1024));
println!(" Total Memory: {} MB", info.total_memory / (1024 * 1024));
println!(" Free Memory: {} MB", info.free_memory / (1024 * 1024));
println!(
" Compute Capability: {}.{}",
info.compute_major, info.compute_minor
@@ -17,7 +18,7 @@ fn test_device_info() {
println!(" Warp Size: {}", info.warp_size);
println!(" Max Threads/Block: {}", info.max_threads_per_block);
assert!(info.total_memory > 0);
assert!(info.total_memory > 30 * 1024 * 1024 * 1024); // 5090 has 32GB
assert!(info.sm_count > 0);
}