"""
Benchmark latency comparison between different all-reduce implementations.

Compares:
- NCCL all-reduce (may be non-deterministic)
- Reduce-scatter + all-gather (RS+AG, deterministic but slower)
- Deterministic 1-stage kernel (forces fixed accumulation order, deterministic)

Note: The "deterministic kernel" is NOT RS+AG. It uses the 1-stage kernel where
each GPU reads all data from all GPUs and reduces locally in a fixed order.

Usage:
    python bench_amd_deterministic_allreduce.py
"""
|
|
|
|
import multiprocessing as mp
import os
import queue
import socket
import statistics
import sys
import time

import torch
import torch.distributed as dist
|
|
|
|
# Add python directory to path to import sglang modules.
# The "python" directory is resolved relative to this script's own location
# and is put first on sys.path, so it takes precedence over other entries.
script_dir = os.path.dirname(os.path.abspath(__file__))
python_dir = os.path.join(script_dir, "python")
sys.path.insert(0, python_dir)

# This import is unconditional: the script cannot start without it.
from sglang.srt.environ import envs

# Try to import custom all-reduce if available. On ImportError/AttributeError
# the benchmark still runs, with the custom paths disabled via the two
# fallback values assigned in the except branch.
try:
    import sglang.srt.distributed.device_communicators.custom_all_reduce_ops as custom_ar_ops
    from sglang.srt.distributed.device_communicators.custom_all_reduce import (
        CustomAllreduce,
    )

    CUSTOM_AR_AVAILABLE = custom_ar_ops.IS_CUSTOM_AR_AVAILABLE
except (ImportError, AttributeError):
    CUSTOM_AR_AVAILABLE = False
    CustomAllreduce = None

# Note: sglang's optimized all-reduce requires full runtime initialization
# and won't work in standalone benchmarks, so we skip it
SGLANG_AVAILABLE = False
|
|
|
|
|
|
def get_open_port():
    """Return a TCP port number that the OS handed out on 127.0.0.1.

    The socket is bound to port 0 so the OS picks the port; the socket is
    closed before returning, so the port is only known to have been free at
    the time of the call.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("127.0.0.1", 0))
        _host, port = probe.getsockname()
        return port
    finally:
        probe.close()
|
|
|
|
|
|
def init_custom_ar_if_available(rank, world_size, device):
    """Report whether the custom all-reduce path should be attempted.

    Despite the name, nothing is initialized here: the function only checks
    the module-level availability flags and the world size.

    Args:
        rank: Unused; kept for interface compatibility.
        world_size: Number of participating ranks.
        device: Unused; kept for interface compatibility.

    Returns:
        True when custom all-reduce was imported successfully and
        ``world_size`` is even and at most 8, otherwise False.
    """
    if CustomAllreduce is None or not CUSTOM_AR_AVAILABLE:
        return False

    # Custom AR works best for single-node, even number of GPUs, world_size <= 8
    fits_size_limit = world_size <= 8
    is_even = world_size % 2 == 0
    return fits_size_limit and is_even
|
|
|
|
|
|
def reduce_scatter_then_all_gather(tensor, rank, world_size, custom_ar=None):
    """
    Sum-all-reduce ``tensor`` in place using reduce-scatter + all-gather.

    Intended as the deterministic alternative to a plain all-reduce (fixed
    ordering, no atomics, per the benchmark's design); whether the result is
    actually reproducible is what the benchmark measures.

    When the element count is not divisible by ``world_size`` the function
    falls back to all-gathering every rank's full tensor and summing locally.

    Args:
        tensor: Tensor to reduce; overwritten with the reduced result.
        rank: Unused by this implementation.
        world_size: Number of ranks in the default process group.
        custom_ar: Unused by this implementation; the default process group
            collectives are always used.
    """
    numel = tensor.numel()

    if numel % world_size != 0:
        # Fallback to all-gather + local reduce if not divisible
        per_rank_copies = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(per_rank_copies, tensor)
        tensor.copy_(torch.stack(per_rank_copies, dim=0).sum(dim=0))
        return

    per_rank = numel // world_size
    flat = tensor.view(-1)

    # Reduce-scatter: this rank ends up owning one reduced slice.
    reduced_slice = torch.empty(per_rank, dtype=tensor.dtype, device=tensor.device)
    send_slices = []
    for idx in range(world_size):
        lo = idx * per_rank
        send_slices.append(flat[lo : lo + per_rank].clone())
    dist.reduce_scatter(reduced_slice, send_slices)

    # All-gather: every rank collects all reduced slices.
    recv_slices = []
    for _ in range(world_size):
        recv_slices.append(
            torch.empty(per_rank, dtype=tensor.dtype, device=tensor.device)
        )
    dist.all_gather(recv_slices, reduced_slice)

    # Stitch the slices together and write back in the caller's shape.
    merged = torch.cat(recv_slices, dim=0)
    tensor.copy_(merged.view(tensor.shape))
|
|
|
|
|
|
def _percent_of(part, whole):
    """Return ``part`` as a percentage of ``whole``."""
    return (part / whole) * 100


def _measure_trials(base_input, num_trials, op, in_place=True):
    """Time ``op`` on fresh clones of ``base_input`` and fingerprint each result.

    Args:
        base_input: Tensor cloned anew for every trial so all trials see the
            same input.
        num_trials: Number of timed repetitions.
        op: Callable invoked with the trial tensor.
        in_place: When True, ``op`` receives a flattened view of the clone and
            is expected to overwrite it; when False, ``op`` receives the clone
            unchanged and its return value is taken as the result.

    Returns:
        ``(outcomes, latencies)`` where ``outcomes`` is a list of
        ``(checksum, first_five_values)`` tuples and ``latencies`` is a list of
        wall-clock durations in seconds.

    Raises:
        RuntimeError: If ``in_place`` is False and ``op`` returns None.
    """
    outcomes = []
    latencies = []
    for _ in range(num_trials):
        trial_input = base_input.clone()
        if in_place:
            trial_input = trial_input.view(-1)

        # Synchronize on both sides so the timer brackets the GPU work.
        torch.cuda.synchronize()
        start = time.perf_counter()
        returned = op(trial_input)
        torch.cuda.synchronize()
        latencies.append(time.perf_counter() - start)

        if in_place:
            flat_result = trial_input
        else:
            # NOTE(review): presumably the kernel can decline an input and
            # return None -- confirm against CustomAllreduce. Fail with a
            # clear message instead of an incidental AttributeError.
            if returned is None:
                raise RuntimeError("operation returned no result tensor")
            flat_result = returned.view(-1)

        # Store checksum and first values (like test_ar.py)
        outcomes.append((flat_result.sum().item(), flat_result[:5].clone()))
    return outcomes, latencies


def _summarize_trials(outcomes, latencies):
    """Compare every trial against the first one and summarize the series.

    A trial counts as a mismatch when its checksum differs from the reference
    by more than 1e-3 or its leading values fail ``torch.allclose`` with
    ``rtol=1e-3``.

    Returns:
        Dict with ``latency_median`` (seconds), ``deterministic`` (bool) and
        ``max_variance`` (largest checksum deviation among mismatching trials,
        0.0 when there are none).
    """
    ref_sum, ref_vals = outcomes[0]
    deviations = []
    for checksum, vals in outcomes[1:]:
        mismatch = abs(ref_sum - checksum) > 1e-3 or not torch.allclose(
            ref_vals, vals, rtol=1e-3
        )
        if mismatch:
            deviations.append(abs(ref_sum - checksum))
    return {
        "latency_median": statistics.median(latencies),
        "deterministic": not deviations,
        "max_variance": max(deviations) if deviations else 0.0,
    }


def _setup_custom_ar(rank, world_size, device):
    """Create the custom all-reduce communicator, or return None on every rank.

    All ranks must take the same code path afterwards (the benchmark issues
    extra collectives only when the communicator exists), so the per-rank
    outcome is reduced with MIN across ranks: if any rank failed, every rank
    falls back to NCCL.
    """
    custom_ar = None
    if init_custom_ar_if_available(rank, world_size, device):
        try:
            # Create a gloo group for custom AR (it requires non-NCCL backend)
            # All ranks must call new_group with the same parameters
            dist.barrier()  # Ensure all ranks are ready
            ar_group = dist.new_group(backend="gloo")
            dist.barrier()  # Ensure group creation is complete
            custom_ar = CustomAllreduce(group=ar_group, device=device)
        except Exception as e:
            if rank == 0:
                print(f" Custom AR init failed: {e}, using NCCL fallback")
            custom_ar = None

    # Agree on availability across ranks; this also acts as the
    # synchronization point that follows initialization.
    ready = torch.tensor(
        [0 if custom_ar is None else 1], dtype=torch.int32, device=device
    )
    dist.all_reduce(ready, op=dist.ReduceOp.MIN)
    if custom_ar is not None and ready.item() == 0:
        if rank == 0:
            print(" Custom AR unavailable on at least one rank, using NCCL fallback")
        custom_ar = None

    if custom_ar is not None and rank == 0:
        print(" Using custom all-reduce (deterministic)")
    return custom_ar


def worker(world_size, rank, port, results_queue):
    """Per-GPU benchmark process.

    Joins an NCCL process group, then for each batch size runs repeated trials
    of NCCL all-reduce, RS+AG, and (when the custom communicator is available)
    the custom paths. Rank 0 checks each series for run-to-run agreement,
    prints a per-batch report and finally puts the results dict on
    ``results_queue``.

    Args:
        world_size: Number of participating processes/GPUs.
        rank: This process's rank; also selects ``cuda:{rank}``.
        port: TCP port of the rendezvous on localhost.
        results_queue: Queue on which rank 0 publishes ``{batch_size: stats}``.
    """
    envs.SGLANG_USE_1STAGE_ALLREDUCE.set("1")
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://localhost:{port}",
        rank=rank,
        world_size=world_size,
    )

    # Try to initialize custom all-reduce if available
    custom_ar = _setup_custom_ar(rank, world_size, device)

    # Test different batch sizes - similar to test_ar.py
    batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256, 512]
    hidden_dim = 16384  # Fixed hidden dimension
    num_trials = 10  # Same as test_ar.py

    # Different seed per rank - each GPU has DIFFERENT input (like test_ar.py)
    torch.manual_seed(42 + rank)

    def nccl_all_reduce(buf):
        dist.all_reduce(buf, op=dist.ReduceOp.SUM)

    def rs_ag(buf):
        reduce_scatter_then_all_gather(buf, rank, world_size, custom_ar=None)

    def rs_ag_with_custom_ar(buf):
        # NOTE: reduce_scatter_then_all_gather currently ignores ``custom_ar``,
        # so this series measures the same collectives as ``rs_ag``.
        reduce_scatter_then_all_gather(buf, rank, world_size, custom_ar=custom_ar)

    results = {}

    for bs in batch_sizes:
        # Create fixed input for all trials (like test_ar.py)
        base_input = torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device)

        dist.barrier()

        if rank == 0:
            print(f"\nBatch size {bs:4d}:")
            print(f" Testing determinism across {num_trials} trials...")

        trials_ar = _measure_trials(base_input, num_trials, nccl_all_reduce)
        trials_rs_ag = _measure_trials(base_input, num_trials, rs_ag)

        # Note: sglang's optimized all-reduce requires full runtime
        # initialization and is not tested in this standalone benchmark.

        trials_custom = None
        if custom_ar is not None:
            trials_custom = _measure_trials(
                base_input, num_trials, rs_ag_with_custom_ar
            )

        # Test deterministic kernel (if available)
        trials_kernel = None
        if custom_ar is not None:
            # Check if input size fits in buffer
            input_size_bytes = base_input.numel() * base_input.element_size()
            if input_size_bytes > custom_ar.max_size:
                if rank == 0:
                    print(
                        f" Deterministic kernel skipped: input size ({input_size_bytes/(1024*1024):.1f} MB) > buffer size ({custom_ar.max_size/(1024*1024):.1f} MB)"
                    )
            else:
                try:
                    trials_kernel = _measure_trials(
                        base_input,
                        num_trials,
                        custom_ar.custom_all_reduce,
                        in_place=False,
                    )
                except Exception as e:
                    if rank == 0:
                        print(
                            f" Deterministic kernel test failed for batch size {bs}: {e}"
                        )
                    trials_kernel = None

        dist.barrier()

        # Only rank 0 analyses and reports.
        if rank != 0:
            continue

        ar_stats = _summarize_trials(*trials_ar)
        rs_ag_stats = _summarize_trials(*trials_rs_ag)
        custom_stats = (
            _summarize_trials(*trials_custom) if trials_custom is not None else None
        )
        kernel_stats = (
            _summarize_trials(*trials_kernel) if trials_kernel is not None else None
        )

        lat_ar = ar_stats["latency_median"]
        lat_rs_ag = rs_ag_stats["latency_median"]
        overhead_rs_ag = _percent_of(lat_rs_ag - lat_ar, lat_ar)

        results[bs] = {
            "all_reduce": ar_stats,
            "rs_ag": rs_ag_stats,
            "custom_ar": custom_stats,
            "deterministic_kernel": kernel_stats,
            # Never measured in this standalone benchmark; key kept so the
            # results schema stays stable for consumers.
            "optimized_rs_ag": None,
            "overhead_rs_ag_pct": overhead_rs_ag,
        }

        print(
            f" All-Reduce: {lat_ar*1000:.3f}ms, Deterministic: {ar_stats['deterministic']}, Max variance: {ar_stats['max_variance']:.6f}"
        )
        print(
            f" RS+All-Gather: {lat_rs_ag*1000:.3f}ms, Deterministic: {rs_ag_stats['deterministic']}, Max variance: {rs_ag_stats['max_variance']:.6f}"
        )
        if custom_stats is not None:
            lat_custom = custom_stats["latency_median"]
            overhead_custom = _percent_of(lat_custom - lat_ar, lat_ar)
            print(
                f" Custom AR: {lat_custom*1000:.3f}ms, Deterministic: {custom_stats['deterministic']}, Max variance: {custom_stats['max_variance']:.6f}, Overhead: {overhead_custom:+.1f}%"
            )
        if kernel_stats is not None:
            lat_kernel = kernel_stats["latency_median"]
            overhead_kernel = _percent_of(lat_kernel - lat_ar, lat_ar)
            speedup_kernel_vs_rs_ag = _percent_of(lat_rs_ag - lat_kernel, lat_rs_ag)
            print(
                f" Deterministic Kernel: {lat_kernel*1000:.3f}ms, Deterministic: {kernel_stats['deterministic']}, Max variance: {kernel_stats['max_variance']:.6f}, Overhead: {overhead_kernel:+.1f}%, Speedup vs RS+AG: {speedup_kernel_vs_rs_ag:+.1f}%"
            )
        print(f" RS+AG Overhead: {overhead_rs_ag:+.1f}%")

    if rank == 0:
        results_queue.put(results)

    dist.destroy_process_group()
|
|
|
|
|
|
def _overhead_pct(value, baseline):
    """Return how much larger ``value`` is than ``baseline``, in percent."""
    return ((value - baseline) / baseline) * 100


def _speedup_pct(baseline, value):
    """Return how much smaller ``value`` is than ``baseline``, in percent."""
    return ((baseline - value) / baseline) * 100


def _det_mark(flag):
    """Return the table symbol for a determinism flag."""
    return "✓" if flag else "✗"


def _print_pct_stats(title, values):
    """Print average/median/min/max of a list of percentages under ``title``."""
    print(title)
    print(f" Average: {statistics.mean(values):.1f}%")
    print(f" Median: {statistics.median(values):.1f}%")
    print(f" Min: {min(values):.1f}%")
    print(f" Max: {max(values):.1f}%")


def _await_rank0_results(results_queue, rank0_proc):
    """Fetch the results dict from the queue, or None if rank 0 died first.

    The queue is drained BEFORE the workers are joined: joining a process that
    still has buffered queue data can block forever, and ``Queue.empty()`` is
    not a reliable signal.
    """
    while True:
        try:
            return results_queue.get(timeout=1.0)
        except queue.Empty:
            if rank0_proc.is_alive():
                continue
            # Rank 0 has exited; allow for data still in flight, then give up.
            try:
                return results_queue.get(timeout=1.0)
            except queue.Empty:
                return None


def _print_results_table(results):
    """Print one summary row per batch size, with optional column groups."""
    rows = results.values()
    header = f"{'Batch':<8} {'AR (ms)':<12} {'AR Det':<8} {'RS+AG (ms)':<15} {'RS+AG Det':<10} {'RS+AG Ovh':<12}"
    if any(r.get("custom_ar") is not None for r in rows):
        header += f" {'Custom AR (ms)':<18} {'Custom AR Det':<15} {'Custom AR Ovh':<15}"
    if any(r.get("deterministic_kernel") is not None for r in rows):
        header += f" {'Det Kernel (ms)':<18} {'Det Kernel Det':<15} {'Det Kernel Ovh':<15} {'Speedup':<10}"
    if any(r.get("optimized_rs_ag") is not None for r in rows):
        header += f" {'Opt RS+AG (ms)':<18} {'Opt RS+AG Det':<15} {'Opt RS+AG Ovh':<15} {'Speedup':<10}"
    print(header)
    print("-" * 150)

    for bs in sorted(results):
        r = results[bs]
        lat_ar = r["all_reduce"]["latency_median"]
        lat_rs_ag = r["rs_ag"]["latency_median"]
        ar_det_str = _det_mark(r["all_reduce"]["deterministic"])
        rs_ag_det_str = _det_mark(r["rs_ag"]["deterministic"])
        line = (
            f"{bs:<8} {lat_ar*1000:<12.3f} {ar_det_str:<8} "
            f"{lat_rs_ag*1000:<15.3f} {rs_ag_det_str:<10} "
            f"{r['overhead_rs_ag_pct']:<12.1f}"
        )

        custom_ar = r.get("custom_ar")
        if custom_ar is not None:
            det_str = _det_mark(custom_ar["deterministic"])
            overhead = _overhead_pct(custom_ar["latency_median"], lat_ar)
            line += f" {custom_ar['latency_median']*1000:<18.3f} {det_str:<15} {overhead:<15.1f}"

        # Both remaining groups share the same four-column layout.
        for key in ("deterministic_kernel", "optimized_rs_ag"):
            entry = r.get(key)
            if entry is None:
                continue
            det_str = _det_mark(entry["deterministic"])
            overhead = _overhead_pct(entry["latency_median"], lat_ar)
            speedup = _speedup_pct(lat_rs_ag, entry["latency_median"])
            line += f" {entry['latency_median']*1000:<18.3f} {det_str:<15} {overhead:<15.1f} {speedup:<10.1f}"
        print(line)


def _print_summary_statistics(results):
    """Print determinism counts, overhead/speedup statistics and variances."""
    rows = list(results.values())
    custom_rows = [r for r in rows if r.get("custom_ar") is not None]
    kernel_rows = [r for r in rows if r.get("deterministic_kernel") is not None]

    ar_deterministic_count = sum(1 for r in rows if r["all_reduce"]["deterministic"])
    rs_ag_deterministic_count = sum(1 for r in rows if r["rs_ag"]["deterministic"])
    custom_ar_deterministic_count = sum(
        1 for r in rows if r.get("custom_ar") and r["custom_ar"]["deterministic"]
    )
    deterministic_kernel_deterministic_count = sum(
        1
        for r in rows
        if r.get("deterministic_kernel")
        and r["deterministic_kernel"]["deterministic"]
    )

    print("\nDeterminism Summary:")
    print(
        f" All-Reduce deterministic: {ar_deterministic_count}/{len(results)} batch sizes"
    )
    print(
        f" RS+All-Gather deterministic: {rs_ag_deterministic_count}/{len(results)} batch sizes"
    )
    if custom_rows:
        print(
            f" Custom AR deterministic: {custom_ar_deterministic_count}/{len(custom_rows)} batch sizes"
        )
    if kernel_rows:
        print(
            f" Deterministic Kernel deterministic: {deterministic_kernel_deterministic_count}/{len(kernel_rows)} batch sizes"
        )

    _print_pct_stats(
        "\nLatency Overhead Statistics (RS+AG vs All-Reduce):",
        [r["overhead_rs_ag_pct"] for r in rows],
    )

    if custom_rows:
        _print_pct_stats(
            "\nLatency Overhead Statistics (Custom AR vs All-Reduce):",
            [
                _overhead_pct(
                    r["custom_ar"]["latency_median"],
                    r["all_reduce"]["latency_median"],
                )
                for r in custom_rows
            ],
        )

    if kernel_rows:
        _print_pct_stats(
            "\nLatency Overhead Statistics (Deterministic Kernel vs All-Reduce):",
            [
                _overhead_pct(
                    r["deterministic_kernel"]["latency_median"],
                    r["all_reduce"]["latency_median"],
                )
                for r in kernel_rows
            ],
        )
        _print_pct_stats(
            "\nSpeedup Statistics (Deterministic Kernel vs RS+AG):",
            [
                _speedup_pct(
                    r["rs_ag"]["latency_median"],
                    r["deterministic_kernel"]["latency_median"],
                )
                for r in kernel_rows
            ],
        )

    # Show variance for non-deterministic cases
    print("\nVariance Analysis (non-deterministic cases):")
    labelled_keys = (
        ("all_reduce", "All-Reduce"),
        ("rs_ag", "RS+All-Gather"),
        ("custom_ar", "Custom AR"),
        ("deterministic_kernel", "Deterministic Kernel"),
    )
    for bs in sorted(results):
        r = results[bs]
        for key, label in labelled_keys:
            entry = r.get(key)
            if entry is not None and not entry["deterministic"]:
                print(f" Batch {bs}: {label} max variance: {entry['max_variance']:.6f}")


def main():
    """Spawn one worker per GPU, collect rank 0's results and print a summary.

    Uses up to 8 GPUs (fewer when fewer are visible) and aborts when fewer
    than 2 are available. Abnormal worker exits and a missing result are
    reported instead of being silently ignored.
    """
    world_size = 8
    available_gpus = torch.cuda.device_count()

    print("=" * 80)
    print("All-Reduce vs Reduce-Scatter + All-Gather Determinism & Latency Benchmark")
    print("=" * 80)
    print(f"Available GPUs: {available_gpus}")
    print(f"Using world_size: {world_size}")
    print("Hidden dimension: 16384")
    print("Tensor dtype: bfloat16")
    print("Trials per batch size: 10 (testing determinism)")
    print("Testing batch sizes: [1, 4, 8, 16, 32, 64, 128, 256, 512]")
    print("=" * 80)

    if available_gpus < world_size:
        print(
            f"WARNING: Only {available_gpus} GPUs available, using {available_gpus} instead"
        )
        world_size = available_gpus

    if world_size < 2:
        print("ERROR: Need at least 2 GPUs for this benchmark")
        return

    mp.set_start_method("spawn", force=True)
    port = get_open_port()

    results_queue = mp.Queue()
    procs = []
    for rank in range(world_size):
        p = mp.Process(target=worker, args=(world_size, rank, port, results_queue))
        p.start()
        procs.append(p)

    # Collect results first (procs[0] is rank 0, the only producer), then join.
    results = _await_rank0_results(results_queue, procs[0])

    for p in procs:
        p.join()

    failed = [(rank, p.exitcode) for rank, p in enumerate(procs) if p.exitcode != 0]
    if failed:
        print(f"WARNING: worker processes exited abnormally (rank, exit code): {failed}")

    if not results:
        print("ERROR: No benchmark results were received from rank 0")
        return

    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    _print_results_table(results)
    print("=" * 80)

    _print_summary_statistics(results)
|
|
|
|
|
|
# Run the benchmark only when executed as a script, not on import. This guard
# matters here because main() selects the "spawn" start method, under which
# child processes import this module again.
if __name__ == "__main__":
    main()
|