// GPU integration tests for the tensor abstraction. Both require nvcc + a GPU, // so they are gated behind `not(no_cuda)`. On a GPU-less machine build.rs sets // the `no_cuda` cfg and these compile out, keeping host `cargo check` green. #![cfg(not(no_cuda))] use xtrain_cuda::device; use xtrain_tensor::{Device, Tensor}; /// (a) Host → device → host roundtrip preserves the data exactly. #[test] fn host_device_roundtrip() { assert!( device::device_count().expect("device count") > 0, "no CUDA device" ); device::set_device(0).unwrap(); let host: Vec = (0..1024).map(|i| i as f32 * 0.5).collect(); let cpu = Tensor::from_slice(&host, &[1024]); let gpu = cpu.to_device(Device::Cuda(0)); assert_eq!(gpu.device(), Device::Cuda(0)); assert_eq!(gpu.shape(), &[1024]); let back = gpu.to_device(Device::Cpu); assert_eq!(back.device(), Device::Cpu); assert_eq!(back.as_slice::(), host.as_slice()); println!("roundtrip OK: {} elems preserved", host.len()); } /// (b) The elementwise `scale` kernel produces correct results. #[test] fn elementwise_scale_kernel() { assert!( device::device_count().expect("device count") > 0, "no CUDA device" ); device::set_device(0).unwrap(); let host: Vec = (0..2048).map(|i| i as f32).collect(); let alpha = 3.0f32; let expected: Vec = host.iter().map(|x| x * alpha).collect(); let gpu = Tensor::from_slice(&host, &[2048]).to_device(Device::Cuda(0)); let scaled = gpu.scale(alpha); let result = scaled.to_device(Device::Cpu); assert_eq!(result.shape(), &[2048]); assert_eq!(result.as_slice::(), expected.as_slice()); let r = result.as_slice::(); println!( "scale OK (alpha={alpha}): first={} mid={} last={} ({} elems)", r[0], r[r.len() / 2], r[r.len() - 1], r.len() ); } /// (c) `rope_at` (KV-cache decode RoPE at an absolute position) is bit-identical /// to the full-sequence `rope`'s corresponding row. 
This is the invariant the /// decode KV-cache relies on: a single new token RoPE'd at position `t` must equal /// what the full-sequence forward would have produced at row `t` (so cached /// post-RoPE K matches the full-recompute path → token-identical decode). #[test] fn rope_at_matches_full_rope_row() { assert!( device::device_count().expect("device count") > 0, "no CUDA device" ); device::set_device(0).unwrap(); let (n, heads, hd) = (7usize, 3usize, 8usize); let theta = 10000.0f32; // Deterministic pseudo-random fill in [-1, 1). let host: Vec = (0..n * heads * hd) .map(|i| ((i * 37 % 101) as f32 / 50.0) - 1.0) .collect(); // Full-sequence rope (period = n → row r gets position r). let full = Tensor::from_slice(&host, &[n, heads, hd]).to_device(Device::Cuda(0)); let roped_full = full .rope(theta, n) .to_device(Device::Cpu) .as_slice::() .to_vec(); let row_len = heads * hd; for t in 0..n { let row = &host[t * row_len..(t + 1) * row_len]; let roped_row = Tensor::from_slice(row, &[1, heads, hd]) .to_device(Device::Cuda(0)) .rope_at(theta, t) .to_device(Device::Cpu) .as_slice::() .to_vec(); let expect = &roped_full[t * row_len..(t + 1) * row_len]; assert_eq!( roped_row.as_slice(), expect, "rope_at(pos0={t}) != full rope row {t}" ); } println!("rope_at OK: bit-identical to full rope across {n} positions"); } /// (d) `decode_attention` (single query vs cached K/V, no mask) equals the LAST /// query row of the full causal `attention`. This is the core decode-engine /// invariant: the incremental path must reproduce what the full-recompute forward /// computes for the final position, so KV-cache greedy decode is token-identical. /// Tolerance is fp rounding (different softmax kernel + reduction order), not bits. 
#[test] fn decode_attention_matches_full_attention_last_row() { assert!( device::device_count().expect("device count") > 0, "no CUDA device" ); device::set_device(0).unwrap(); let (bh, t, hd) = (6usize, 5usize, 8usize); let scale = 1.0 / (hd as f32).sqrt(); let n = bh * t * hd; let qh: Vec = (0..n).map(|i| ((i * 31 % 97) as f32 / 48.0) - 1.0).collect(); let kh: Vec = (0..n).map(|i| ((i * 53 % 89) as f32 / 44.0) - 1.0).collect(); let vh: Vec = (0..n).map(|i| ((i * 17 % 83) as f32 / 41.0) - 1.0).collect(); let q = Tensor::from_slice(&qh, &[bh, t, hd]).to_device(Device::Cuda(0)); let k = Tensor::from_slice(&kh, &[bh, t, hd]).to_device(Device::Cuda(0)); let v = Tensor::from_slice(&vh, &[bh, t, hd]).to_device(Device::Cuda(0)); // Reference: full causal attention, take each head's last query row. let (full, _) = q.attention(&k, &v, scale); let full_h = full.to_device(Device::Cpu).as_slice::().to_vec(); // Decode: build Q_last [bh,1,hd] from each head's last row, attend to all K/V. let mut ql = vec![0f32; bh * hd]; for b in 0..bh { let src = (b * t + (t - 1)) * hd; ql[b * hd..(b + 1) * hd].copy_from_slice(&qh[src..src + hd]); } let q_last = Tensor::from_slice(&ql, &[bh, 1, hd]).to_device(Device::Cuda(0)); let dec = q_last .decode_attention(&k, &v, scale) .to_device(Device::Cpu) .as_slice::() .to_vec(); assert_eq!(dec.len(), bh * hd, "decode out shape"); let mut max_abs = 0f32; for b in 0..bh { for d in 0..hd { let got = dec[b * hd + d]; let exp = full_h[(b * t + (t - 1)) * hd + d]; max_abs = max_abs.max((got - exp).abs()); } } assert!( max_abs < 1e-4, "decode_attention vs full last-row max abs diff {max_abs} exceeds 1e-4" ); println!("decode_attention OK: matches full causal last row (bh={bh}, t={t}, max|Δ|={max_abs:.2e})"); } /// (e) `rope_pos` (per-row positions, M2b batched decode): with positions /// [0,1,…,n-1] it is bit-identical to the full-sequence `rope` (period=n); with a /// uniform position P every row matches `rope_at(·, P)` of that single row. 
This is /// the primitive the batched decode uses (G rows sharing one decode position). #[test] fn rope_pos_matches_rope_and_rope_at() { assert!(device::device_count().expect("device count") > 0, "no CUDA device"); device::set_device(0).unwrap(); let (n, heads, hd) = (7usize, 3usize, 8usize); let theta = 10000.0f32; let host: Vec = (0..n * heads * hd).map(|i| ((i * 37 % 101) as f32 / 50.0) - 1.0).collect(); let x = Tensor::from_slice(&host, &[n, heads, hd]).to_device(Device::Cuda(0)); // positions [0,1,…,n-1] ⇒ identical to the full-sequence rope. let seq_pos: Vec = (0..n as i32).collect(); let pos_t = Tensor::from_slice(&seq_pos, &[n]).to_device(Device::Cuda(0)); let got = x.rope_pos(&pos_t, theta).to_device(Device::Cpu).as_slice::().to_vec(); let want = x.rope(theta, n).to_device(Device::Cpu).as_slice::().to_vec(); assert_eq!(got, want, "rope_pos [0..n] != full rope"); // uniform position P ⇒ each row matches rope_at(single row, P). let p = 5i32; let uni = Tensor::from_slice(&vec![p; n], &[n]).to_device(Device::Cuda(0)); let got_u = x.rope_pos(&uni, theta).to_device(Device::Cpu).as_slice::().to_vec(); let row_len = heads * hd; for t in 0..n { let row = &host[t * row_len..(t + 1) * row_len]; let want_row = Tensor::from_slice(row, &[1, heads, hd]) .to_device(Device::Cuda(0)) .rope_at(theta, p as usize) .to_device(Device::Cpu) .as_slice::() .to_vec(); assert_eq!(&got_u[t * row_len..(t + 1) * row_len], want_row.as_slice(), "uniform pos row {t}"); } println!("rope_pos OK: == full rope for [0..n] and == rope_at(P) per row for uniform P"); } /// (f) `cat_seq` (device-side KV-cache append, M2c): concatenating [bh,ta,hd] ++ /// [bh,tb,hd] along the seq dim equals the host-side interleaved concat (per bh row, /// a's block then b's block). This is the device append that removes the M2a/M2b /// host round-trip. 
#[test]
fn cat_seq_matches_host_concat() {
    // Fail loudly (rather than silently pass) when no GPU is visible.
    assert!(
        device::device_count().expect("device count") > 0,
        "no CUDA device"
    );
    device::set_device(0).unwrap();

    let (bh, ta, tb, hd) = (4usize, 3usize, 2usize, 5usize);
    // `a` is non-negative and `b` strictly negative, so a misplaced block is
    // immediately visible in a failing diff.
    let ah: Vec<f32> = (0..bh * ta * hd).map(|i| i as f32 * 0.1).collect();
    let bhost: Vec<f32> = (0..bh * tb * hd).map(|i| -(i as f32) - 1.0).collect();

    let a = Tensor::from_slice(&ah, &[bh, ta, hd]).to_device(Device::Cuda(0));
    let b = Tensor::from_slice(&bhost, &[bh, tb, hd]).to_device(Device::Cuda(0));
    let got = a
        .cat_seq(&b)
        .to_device(Device::Cpu)
        .as_slice::<f32>()
        .to_vec();

    // Host reference: per bh row, a's ta*hd then b's tb*hd.
    let want: Vec<f32> = ah
        .chunks_exact(ta * hd)
        .zip(bhost.chunks_exact(tb * hd))
        .flat_map(|(a_row, b_row)| a_row.iter().chain(b_row).copied())
        .collect();
    assert_eq!(want.len(), bh * (ta + tb) * hd, "host reference length");

    assert_eq!(got, want, "cat_seq != host interleaved concat");
    println!("cat_seq OK: [bh={bh},{ta}+{tb},{hd}] == host concat");
}