model: tensor-parallel Qwen3 (sharded weights + AllReduce)
from_weights_tp shards each rank's weights (column-split q/k/v/gate/up, row-split o/down; replicate norms/embed/lm_head) and the paged forward uses local head counts + AllReduces after o_proj and down_proj. PagedKVCache::new_tp sizes the pool for the rank's local KV heads (KV is sharded too). TP=1 is the identity path. New bench-tp binary runs E2E multi-GPU generation per TP degree. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -8,6 +8,7 @@ xserv-cuda = { path = "../xserv-cuda" }
|
||||
xserv-tensor = { path = "../xserv-tensor" }
|
||||
xserv-kernels = { path = "../xserv-kernels" }
|
||||
xserv-tokenizer = { path = "../xserv-tokenizer" }
|
||||
xserv-distributed = { path = "../xserv-distributed" }
|
||||
half.workspace = true
|
||||
smallvec.workspace = true
|
||||
serde.workspace = true
|
||||
|
||||
194
crates/xserv-model/src/bin/bench-tp.rs
Normal file
194
crates/xserv-model/src/bin/bench-tp.rs
Normal file
@@ -0,0 +1,194 @@
|
||||
//! Tensor-parallel E2E benchmark for Qwen3.
|
||||
//!
|
||||
//! Spawns one thread per TP rank (each bound to one GPU), loads the sharded
|
||||
//! model, and runs greedy autoregressive generation. Because lm_head is
|
||||
//! replicated and the post-AllReduce hidden state is identical on every rank,
|
||||
//! all ranks compute identical logits and pick the same greedy token — so the
|
||||
//! rank threads stay in lockstep via the per-layer AllReduces without any
|
||||
//! token broadcast. Rank 0 records output + timings.
|
||||
//!
|
||||
//! Usage: bench-tp <model-dir> [--tp N] [--gen-tokens N] [--devices 0,1,2,3]
|
||||
//!
|
||||
//! Run with --tp 1 / 2 / 4 and compare the printed text (correctness) and
|
||||
//! tok/s (performance).
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::time::Instant;
|
||||
|
||||
use xserv_model::qwen3::sample_greedy;
|
||||
use xserv_model::{loader, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
/// Per-prompt outcome recorded by rank 0 in `run_rank` and printed by `main`.
struct PromptResult {
    /// Generated token ids: the token sampled from the prefill logits followed
    /// by every token sampled during the decode loop.
    gen_ids: Vec<u32>,
    /// Wall-clock milliseconds from the start of prefill until the first token
    /// has been sampled.
    ttft_ms: f64,
    /// Decode steps divided by decode wall-clock seconds (the first token is not
    /// counted); 0.0 when no decode step ran.
    decode_tok_s: f64,
}
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!("Usage: bench-tp <model-dir> [--tp N] [--gen-tokens N] [--devices 0,1,2,3]");
|
||||
std::process::exit(1);
|
||||
}
|
||||
let model_dir = PathBuf::from(&args[1]);
|
||||
let world: usize = arg(&args, "--tp").and_then(|s| s.parse().ok()).unwrap_or(1).max(1);
|
||||
let gen_tokens: usize = arg(&args, "--gen-tokens").and_then(|s| s.parse().ok()).unwrap_or(64);
|
||||
let devices: Vec<u32> = match arg(&args, "--devices") {
|
||||
Some(s) => s.split(',').filter_map(|d| d.trim().parse().ok()).collect(),
|
||||
None => (0..world as u32).collect(),
|
||||
};
|
||||
assert_eq!(devices.len(), world, "--devices count must equal --tp");
|
||||
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
assert!(
|
||||
config.num_kv_heads() % world == 0,
|
||||
"num_kv_heads {} not divisible by tp {world}",
|
||||
config.num_kv_heads()
|
||||
);
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
let eos = tokenizer.eos_token_id();
|
||||
|
||||
let prompts: Vec<&str> = vec![
|
||||
"The capital of France is",
|
||||
"Explain photosynthesis in one sentence.",
|
||||
"Write a haiku about the ocean.",
|
||||
"List three uses of a hammer.",
|
||||
"What is the speed of light?",
|
||||
"Describe the water cycle briefly.",
|
||||
"Who wrote Romeo and Juliet?",
|
||||
"Translate 'good morning' into Spanish.",
|
||||
];
|
||||
let prompt_ids: Vec<Vec<u32>> = prompts.iter().map(|p| tokenizer.encode(p)).collect();
|
||||
|
||||
// Tensors are not Send (their Storage holds a raw GPU pointer), so each rank
|
||||
// thread loads its own CPU copy of the weights and shards in-thread. Loading
|
||||
// is not part of the timed region.
|
||||
let id = if world > 1 { Some(xserv_distributed::get_unique_id()) } else { None };
|
||||
|
||||
let handles: Vec<_> = (0..world)
|
||||
.map(|rank| {
|
||||
let model_dir = model_dir.clone();
|
||||
let config = config.clone();
|
||||
let prompt_ids = prompt_ids.clone();
|
||||
let device = devices[rank];
|
||||
thread::spawn(move || {
|
||||
run_rank(rank, world, device, id, config, model_dir, prompt_ids, gen_tokens, eos)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut rank0: Option<Vec<PromptResult>> = None;
|
||||
for (rank, h) in handles.into_iter().enumerate() {
|
||||
let r = h.join().expect("rank thread panicked");
|
||||
if rank == 0 {
|
||||
rank0 = r;
|
||||
}
|
||||
}
|
||||
|
||||
let results = rank0.expect("rank 0 produced no results");
|
||||
println!("\n=== TP={world} (devices {devices:?}) — Qwen3 E2E benchmark ===");
|
||||
println!("{:<45} {:>10} {:>12} {:>8}", "prompt", "TTFT(ms)", "decode tok/s", "gen");
|
||||
let mut tps_sum = 0.0;
|
||||
for (i, r) in results.iter().enumerate() {
|
||||
let text = tokenizer.decode(&r.gen_ids).replace('\n', " ");
|
||||
let short: String = text.chars().take(50).collect();
|
||||
let p: String = prompts[i].chars().take(43).collect();
|
||||
println!(
|
||||
"{:<45} {:>10.1} {:>12.1} {:>8} | {}",
|
||||
p, r.ttft_ms, r.decode_tok_s, r.gen_ids.len(), short
|
||||
);
|
||||
tps_sum += r.decode_tok_s;
|
||||
}
|
||||
println!("--- mean decode throughput: {:.1} tok/s ---", tps_sum / results.len() as f64);
|
||||
|
||||
// Machine-readable line for cross-TP correctness diffing (rank 0 token ids).
|
||||
let all_ids: Vec<String> = results
|
||||
.iter()
|
||||
.map(|r| r.gen_ids.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(","))
|
||||
.collect();
|
||||
println!("CORRECTNESS_IDS tp={world} {}", all_ids.join(" | "));
|
||||
}
|
||||
|
||||
fn run_rank(
|
||||
rank: usize,
|
||||
world: usize,
|
||||
device: u32,
|
||||
id: Option<xserv_distributed::UniqueId>,
|
||||
config: ModelConfig,
|
||||
model_dir: PathBuf,
|
||||
prompt_ids: Vec<Vec<u32>>,
|
||||
gen_tokens: usize,
|
||||
eos: Option<u32>,
|
||||
) -> Option<Vec<PromptResult>> {
|
||||
// Bind this thread to its GPU and set up the TP communicator.
|
||||
let tp = if world > 1 {
|
||||
Some(Arc::new(xserv_distributed::TpContext::init(rank, world, id.unwrap(), device)))
|
||||
} else {
|
||||
xserv_cuda::device::set_device(device).unwrap();
|
||||
None
|
||||
};
|
||||
|
||||
// Load this rank's own CPU copy of the weights and shard in-thread.
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cpu);
|
||||
let model = Qwen3::from_weights_tp(config.clone(), weights, rank, world, device, tp.clone());
|
||||
|
||||
// Per-rank paged KV cache holds only this rank's local KV heads.
|
||||
let local_kv = config.num_kv_heads() / world;
|
||||
let max_seq = 2048usize;
|
||||
let max_blocks_per_seq = max_seq.div_ceil(BLOCK_SIZE);
|
||||
let total_blocks = max_blocks_per_seq + 8;
|
||||
let mut cache = PagedKVCache::new_tp(
|
||||
&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, device,
|
||||
);
|
||||
|
||||
// Warmup (init kernels / allocator / NCCL channels) — not timed.
|
||||
cache.register_sequence(0).unwrap();
|
||||
let _ = model.forward_prefill_paged(&[1u32, 2, 3], 0, &mut cache);
|
||||
cache.free_sequence(0);
|
||||
|
||||
let mut out = Vec::new();
|
||||
for ids in &prompt_ids {
|
||||
cache.register_sequence(0).unwrap();
|
||||
|
||||
// Prefill (TTFT).
|
||||
let t0 = Instant::now();
|
||||
let logits = model.forward_prefill_paged(ids, 0, &mut cache);
|
||||
let first = sample_greedy(&logits);
|
||||
let ttft_ms = t0.elapsed().as_secs_f64() * 1000.0;
|
||||
|
||||
let mut generated = vec![first];
|
||||
|
||||
// Decode.
|
||||
let t1 = Instant::now();
|
||||
let mut steps = 0usize;
|
||||
for _ in 1..gen_tokens {
|
||||
let last = *generated.last().unwrap();
|
||||
if eos == Some(last) {
|
||||
break;
|
||||
}
|
||||
let pos = cache.seq_len(0);
|
||||
let logits = model.forward_decode_paged(&[last], &[pos], &[0], &mut cache);
|
||||
let next = sample_greedy(&logits);
|
||||
generated.push(next);
|
||||
steps += 1;
|
||||
}
|
||||
let decode_s = t1.elapsed().as_secs_f64();
|
||||
let decode_tok_s = if steps > 0 && decode_s > 0.0 { steps as f64 / decode_s } else { 0.0 };
|
||||
|
||||
cache.free_sequence(0);
|
||||
|
||||
if rank == 0 {
|
||||
out.push(PromptResult { gen_ids: generated, ttft_ms, decode_tok_s });
|
||||
}
|
||||
}
|
||||
|
||||
if rank == 0 { Some(out) } else { None }
|
||||
}
|
||||
|
||||
/// Returns the value that follows the first occurrence of `flag` in `args`.
///
/// Yields `None` when `flag` is absent or is the final element (no value after
/// it). The value is returned verbatim, even if it looks like another flag.
fn arg<'a>(args: &'a [String], flag: &str) -> Option<&'a str> {
    // Each window is (candidate flag, following value); the first match wins.
    args.windows(2)
        .find(|pair| pair[0] == flag)
        .map(|pair| pair[1].as_str())
}
|
||||
@@ -134,10 +134,29 @@ impl PagedKVCache {
|
||||
max_blocks_per_seq: usize,
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
) -> Self {
|
||||
Self::new_tp(
|
||||
config, config.num_kv_heads(), total_blocks, cpu_total_blocks,
|
||||
max_seqs, max_blocks_per_seq, dtype, device,
|
||||
)
|
||||
}
|
||||
|
||||
/// Like `new`, but with an explicit `num_kv_heads` — under tensor parallelism
|
||||
/// each rank only stores its `num_kv_heads / world` heads, so the pool is
|
||||
/// sized for the local head count, not the model's full count.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new_tp(
|
||||
config: &ModelConfig,
|
||||
num_kv_heads: usize,
|
||||
total_blocks: usize,
|
||||
cpu_total_blocks: usize,
|
||||
max_seqs: usize,
|
||||
max_blocks_per_seq: usize,
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
) -> Self {
|
||||
assert!(total_blocks >= 2, "need at least 2 blocks (one is sentinel)");
|
||||
let num_layers = config.num_layers();
|
||||
let num_kv_heads = config.num_kv_heads();
|
||||
let head_dim = config.head_dim();
|
||||
let elem_size = dtype.size_bytes();
|
||||
let block_bytes = num_kv_heads * BLOCK_SIZE * head_dim * elem_size;
|
||||
|
||||
@@ -15,6 +15,11 @@ pub struct Qwen3 {
|
||||
norm: Tensor,
|
||||
lm_head_t: Tensor, // precomputed transpose
|
||||
rope_cache: RopeCache,
|
||||
// Tensor parallelism. `tp` is None (or world==1) for single-GPU; otherwise
|
||||
// this rank holds 1/world of the heads and AllReduces after o_proj/down_proj.
|
||||
tp: Option<std::sync::Arc<xserv_distributed::TpContext>>,
|
||||
local_num_heads: usize, // = num_heads / world
|
||||
local_num_kv_heads: usize, // = num_kv_heads / world
|
||||
}
|
||||
|
||||
struct Qwen3Block {
|
||||
@@ -32,15 +37,43 @@ struct Qwen3Block {
|
||||
}
|
||||
|
||||
impl Qwen3 {
|
||||
pub fn from_weights(config: ModelConfig, mut w: HashMap<String, Tensor>) -> Self {
|
||||
/// Single-GPU load (weights already on the target GPU). Equivalent to
|
||||
/// `from_weights_tp(.., rank=0, world=1, device=0, tp=None)`.
|
||||
pub fn from_weights(config: ModelConfig, w: HashMap<String, Tensor>) -> Self {
|
||||
Self::from_weights_tp(config, w, 0, 1, 0, None)
|
||||
}
|
||||
|
||||
/// Tensor-parallel load. `w` may live on CPU or any device; each weight is
|
||||
/// sharded for `rank`/`world`, uploaded to `device`, and transposed.
|
||||
/// `world==1` shards are identity, so this is also the single-GPU path.
|
||||
///
|
||||
/// Split scheme (Megatron-style):
|
||||
/// - column-parallel (split output): q/k/v/gate/up → shard rows of `[out,in]`
|
||||
/// - row-parallel (split input): o/down → shard cols of `[out,in]`
|
||||
/// - replicated: norms, embed_tokens, lm_head
|
||||
pub fn from_weights_tp(
|
||||
config: ModelConfig,
|
||||
mut w: HashMap<String, Tensor>,
|
||||
rank: usize,
|
||||
world: usize,
|
||||
device: u32,
|
||||
tp: Option<std::sync::Arc<xserv_distributed::TpContext>>,
|
||||
) -> Self {
|
||||
crate::init_kernels();
|
||||
let dev = Device::Cuda(device);
|
||||
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
|
||||
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
};
|
||||
// Replicated weight: upload whole to this rank's device.
|
||||
let repl = |t: Tensor| -> Tensor { t.to_device(dev) };
|
||||
// column-parallel: keep this rank's rows of [out, in], upload, transpose → [in, out/world].
|
||||
let col = |t: Tensor| -> Tensor { shard_rows(&t, rank, world).to_device(dev).transpose(0, 1).contiguous() };
|
||||
// row-parallel: keep this rank's cols of [out, in], upload, transpose → [in/world, out].
|
||||
let row = |t: Tensor| -> Tensor { shard_cols(&t, rank, world).to_device(dev).transpose(0, 1).contiguous() };
|
||||
|
||||
let embed_tokens = take(&mut w, "model.embed_tokens.weight");
|
||||
let norm = take(&mut w, "model.norm.weight");
|
||||
let lm_head_raw = take(&mut w, "lm_head.weight");
|
||||
let embed_tokens = repl(take(&mut w, "model.embed_tokens.weight"));
|
||||
let norm = repl(take(&mut w, "model.norm.weight"));
|
||||
let lm_head_t = repl(take(&mut w, "lm_head.weight")).transpose(0, 1).contiguous();
|
||||
|
||||
let rope_cache = RopeCache::new(
|
||||
config.max_seq_len(),
|
||||
@@ -48,33 +81,51 @@ impl Qwen3 {
|
||||
config.rope_theta.unwrap_or(1_000_000.0) as f32,
|
||||
);
|
||||
|
||||
// Precompute transposed weights: [out, in] → [in, out] so we can do x @ wt directly
|
||||
let transpose_w = |t: Tensor| -> Tensor {
|
||||
t.transpose(0, 1).contiguous()
|
||||
};
|
||||
|
||||
let num_layers = config.num_layers();
|
||||
let mut layers = Vec::with_capacity(num_layers);
|
||||
eprintln!("Transposing weights for {} layers...", num_layers);
|
||||
if rank == 0 {
|
||||
eprintln!("Loading+sharding weights for {} layers (world={world})...", num_layers);
|
||||
}
|
||||
for i in 0..num_layers {
|
||||
let p = format!("model.layers.{i}");
|
||||
layers.push(Qwen3Block {
|
||||
input_norm: take(&mut w, &format!("{p}.input_layernorm.weight")),
|
||||
q_proj_wt: transpose_w(take(&mut w, &format!("{p}.self_attn.q_proj.weight"))),
|
||||
k_proj_wt: transpose_w(take(&mut w, &format!("{p}.self_attn.k_proj.weight"))),
|
||||
v_proj_wt: transpose_w(take(&mut w, &format!("{p}.self_attn.v_proj.weight"))),
|
||||
o_proj_wt: transpose_w(take(&mut w, &format!("{p}.self_attn.o_proj.weight"))),
|
||||
q_norm: take(&mut w, &format!("{p}.self_attn.q_norm.weight")),
|
||||
k_norm: take(&mut w, &format!("{p}.self_attn.k_norm.weight")),
|
||||
post_norm: take(&mut w, &format!("{p}.post_attention_layernorm.weight")),
|
||||
gate_proj_wt: transpose_w(take(&mut w, &format!("{p}.mlp.gate_proj.weight"))),
|
||||
up_proj_wt: transpose_w(take(&mut w, &format!("{p}.mlp.up_proj.weight"))),
|
||||
down_proj_wt: transpose_w(take(&mut w, &format!("{p}.mlp.down_proj.weight"))),
|
||||
input_norm: repl(take(&mut w, &format!("{p}.input_layernorm.weight"))),
|
||||
q_proj_wt: col(take(&mut w, &format!("{p}.self_attn.q_proj.weight"))),
|
||||
k_proj_wt: col(take(&mut w, &format!("{p}.self_attn.k_proj.weight"))),
|
||||
v_proj_wt: col(take(&mut w, &format!("{p}.self_attn.v_proj.weight"))),
|
||||
o_proj_wt: row(take(&mut w, &format!("{p}.self_attn.o_proj.weight"))),
|
||||
q_norm: repl(take(&mut w, &format!("{p}.self_attn.q_norm.weight"))),
|
||||
k_norm: repl(take(&mut w, &format!("{p}.self_attn.k_norm.weight"))),
|
||||
post_norm: repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight"))),
|
||||
gate_proj_wt: col(take(&mut w, &format!("{p}.mlp.gate_proj.weight"))),
|
||||
up_proj_wt: col(take(&mut w, &format!("{p}.mlp.up_proj.weight"))),
|
||||
down_proj_wt: row(take(&mut w, &format!("{p}.mlp.down_proj.weight"))),
|
||||
});
|
||||
}
|
||||
|
||||
let lm_head_t = transpose_w(lm_head_raw);
|
||||
Self { config, embed_tokens, layers, norm, lm_head_t, rope_cache }
|
||||
Self {
|
||||
local_num_heads: config.num_heads() / world,
|
||||
local_num_kv_heads: config.num_kv_heads() / world,
|
||||
config,
|
||||
embed_tokens,
|
||||
layers,
|
||||
norm,
|
||||
lm_head_t,
|
||||
rope_cache,
|
||||
tp,
|
||||
}
|
||||
}
|
||||
|
||||
/// In-place AllReduce(sum) of a partial `[*, hidden]` BF16 activation across
|
||||
/// TP ranks (no-op when not tensor-parallel). Used after o_proj and down_proj.
|
||||
#[inline]
|
||||
fn all_reduce(&self, t: &Tensor) {
|
||||
if let Some(tp) = &self.tp {
|
||||
if tp.world > 1 {
|
||||
let ptr = t.storage().gpu_buffer().as_ptr() as *mut std::ffi::c_void;
|
||||
tp.all_reduce_sum_bf16_ptr(ptr, t.numel());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_with_cache(&self, token_ids: &[u32], cache: &mut KVCache) -> Tensor {
|
||||
@@ -273,8 +324,9 @@ impl Qwen3 {
|
||||
assert_eq!(seq_slots.len(), batch);
|
||||
assert!(batch > 0);
|
||||
|
||||
let num_heads = self.config.num_heads();
|
||||
let num_kv_heads = self.config.num_kv_heads();
|
||||
// TP: this rank owns a slice of the heads (local_* == full when world==1).
|
||||
let num_heads = self.local_num_heads;
|
||||
let num_kv_heads = self.local_num_kv_heads;
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
|
||||
@@ -356,6 +408,7 @@ impl Qwen3 {
|
||||
// Plain reshape is a view; merge_heads_gpu would incorrectly swap B<->H.
|
||||
let attn_merged = attn_out.reshape(&[batch, num_heads * head_dim]);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
self.all_reduce(&attn_proj); // TP: sum partial attention outputs
|
||||
|
||||
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
@@ -364,6 +417,7 @@ impl Qwen3 {
|
||||
let up = matmul_2d(&normed, &layer.up_proj_wt);
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
|
||||
self.all_reduce(&down); // TP: sum partial MLP outputs
|
||||
x = add_any(&residual, &down);
|
||||
}
|
||||
|
||||
@@ -387,8 +441,9 @@ impl Qwen3 {
|
||||
) -> Tensor {
|
||||
let new_tokens = token_ids.len();
|
||||
let pos_offset = paged_cache.seq_len(slot);
|
||||
let num_heads = self.config.num_heads();
|
||||
let num_kv_heads = self.config.num_kv_heads();
|
||||
// TP: this rank owns a slice of the heads (local_* == full when world==1).
|
||||
let num_heads = self.local_num_heads;
|
||||
let num_kv_heads = self.local_num_kv_heads;
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
|
||||
@@ -431,6 +486,7 @@ impl Qwen3 {
|
||||
|
||||
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
self.all_reduce(&attn_proj); // TP: sum partial attention outputs
|
||||
|
||||
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
@@ -439,6 +495,7 @@ impl Qwen3 {
|
||||
let up = matmul_2d(&normed, &layer.up_proj_wt);
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
|
||||
self.all_reduce(&down); // TP: sum partial MLP outputs
|
||||
x = add_any(&residual, &down);
|
||||
}
|
||||
|
||||
@@ -549,6 +606,43 @@ impl Qwen3 {
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
/// Keep this rank's contiguous row-block of a 2D `[rows, cols]` BF16 tensor
|
||||
/// (column-parallel split: split the OUTPUT dim). `world==1` returns the whole.
|
||||
/// Input must be a contiguous CPU (or device) BF16 tensor.
|
||||
fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor {
|
||||
if world == 1 { return t.clone(); }
|
||||
let shape = t.shape();
|
||||
assert_eq!(shape.len(), 2, "shard_rows expects 2D weight");
|
||||
let (rows, cols) = (shape[0], shape[1]);
|
||||
assert!(rows % world == 0, "rows {rows} not divisible by world {world}");
|
||||
let local = rows / world;
|
||||
let host = t.to_device(Device::Cpu);
|
||||
let data = host.as_slice::<bf16>();
|
||||
let start = rank * local * cols;
|
||||
let shard = data[start..start + local * cols].to_vec();
|
||||
Tensor::from_slice(&shard, &[local, cols])
|
||||
}
|
||||
|
||||
/// Keep this rank's column-block of a 2D `[rows, cols]` BF16 tensor (row-parallel
|
||||
/// split: split the INPUT dim). Strided copy. `world==1` returns the whole.
|
||||
fn shard_cols(t: &Tensor, rank: usize, world: usize) -> Tensor {
|
||||
if world == 1 { return t.clone(); }
|
||||
let shape = t.shape();
|
||||
assert_eq!(shape.len(), 2, "shard_cols expects 2D weight");
|
||||
let (rows, cols) = (shape[0], shape[1]);
|
||||
assert!(cols % world == 0, "cols {cols} not divisible by world {world}");
|
||||
let local = cols / world;
|
||||
let c0 = rank * local;
|
||||
let host = t.to_device(Device::Cpu);
|
||||
let data = host.as_slice::<bf16>();
|
||||
let mut shard = Vec::with_capacity(rows * local);
|
||||
for r in 0..rows {
|
||||
let base = r * cols + c0;
|
||||
shard.extend_from_slice(&data[base..base + local]);
|
||||
}
|
||||
Tensor::from_slice(&shard, &[rows, local])
|
||||
}
|
||||
|
||||
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
assert_eq!(a.ndim(), 2);
|
||||
assert_eq!(b.ndim(), 2);
|
||||
|
||||
Reference in New Issue
Block a user