The hardcoded traces/sampled_1000req_seed42.jsonl no longer exists; switch the default to the current sampled trace file w600_r0.0015_st30.jsonl and let users override via --trace. Skip Part 4 cleanly when the file is missing instead of relying on os.path.exists.
223 lines
7.9 KiB
Python
223 lines
7.9 KiB
Python
"""Roofline analysis: compute/memory ratio for prefill vs decode
under different sequence lengths and KV cache reuse ratios.

Model: Qwen3-Coder-30B-A3B (MoE)
  - 48 layers, hidden=2048, heads=32, kv_heads=4, head_dim=128
    (query width = 32 * 128 = 4096, wider than the hidden size)
  - MoE: 128 experts, top-8 active, per-expert intermediate=768
    (the config's intermediate_size=6144 is the dense-MLP width, not the expert width)
  - Total params: ~30B, Active params per token: ~3B

GPU: NVIDIA H20
  - BF16 peak: 148 TFLOPS
  - HBM bandwidth: 4.0 TB/s
  - Roofline ridge point: 148/4.0 = 37 FLOP/byte

Part 4 replays a sampled request trace (JSONL; path set with --trace) and is
skipped when that file is missing.
"""

import argparse
|
|
import json
|
|
import math
|
|
|
|
# ===== Model config =====
# Sanity check for these values: attention (Q/K/V/O) + 8 experts per layer over
# 48 layers, plus untied embedding and LM head, gives ~3.3B active parameters
# per token and ~30.5B in total -- consistent with the "30B-A3B" model name.
L = 48  # layers
D = 2048  # hidden dim
H = 32  # attention heads
H_kv = 4  # KV heads (GQA)
# NOTE: H * D_head = 4096 != D, so the Q and O projections are not D x D.
D_head = 128  # head dim
# Per-expert FFN width (moe_intermediate_size). The model config also has
# intermediate_size=6144, but that is the dense-MLP width: using it per expert
# would imply ~14.5B active FFN parameters per token instead of ~3B.
D_ffn = 768  # FFN intermediate (per expert)
N_experts = 128  # total experts
K_experts = 8  # active experts per token
VOCAB = 151936  # vocabulary size (not used by the per-layer cost model)
BYTES = 2  # BF16


# ===== GPU config (H20) =====
PEAK_FLOPS = 148e12  # BF16 dense peak, FLOP/s
HBM_BW = 4.0e12  # bytes/s
RIDGE_POINT = PEAK_FLOPS / HBM_BW  # ~37 FLOP/byte: AI above this is compute-bound

# Report header: model/GPU identity and the ridge point used to classify
# every row below as compute- or memory-bound.
print("\n".join((
    "=" * 80,
    " ROOFLINE ANALYSIS: Prefill vs Decode under KV Cache Reuse",
    " Model: Qwen3-Coder-30B-A3B (MoE 128E top-8) | GPU: H20",
    "=" * 80,
    f" Ridge point: {RIDGE_POINT:.1f} FLOP/byte",
    " Above ridge → compute-bound | Below ridge → memory-bound",
)))

# ===== Per-token compute & memory for each component =====


def attention_prefill_flops(seq_len, new_tokens):
    """Return attention FLOPs, summed over all layers, for ``new_tokens`` tokens
    attending over a ``seq_len``-token context.

    Counts the Q/K/V projections and the output projection for the new tokens,
    plus the two attention matmuls (Q @ K^T and P @ V). Every new token is
    charged for the full ``seq_len`` context, i.e. the causal mask is ignored.
    For a cold prefill (``new_tokens == seq_len``) this overstates the
    attention-score term by about 2x.

    Args:
        seq_len: total context length (cached + new tokens).
        new_tokens: number of tokens whose activations are computed.

    Returns:
        Total FLOPs over ``L`` layers, counting a multiply-add as 2 FLOPs.
    """
    d_q = H * D_head  # query width (heads * head_dim), not the hidden size D
    d_kv = H_kv * D_head
    # Q/K/V projections: D -> d_q, D -> d_kv, D -> d_kv
    qkv_flops = new_tokens * 2 * D * (d_q + 2 * d_kv)
    # Q @ K^T and P @ V: 2 matmuls of 2 * seq_len * d_q FLOPs per new token
    attn_flops = new_tokens * seq_len * d_q * 2 * 2
    # Output projection: d_q -> D
    out_flops = new_tokens * 2 * d_q * D
    return (qkv_flops + attn_flops + out_flops) * L


def attention_prefill_bytes(seq_len, new_tokens, cached_tokens):
    """Return HBM bytes moved by attention, summed over all layers.

    Counts one read of the Q/K/V/O projection weights, a read of
    ``cached_tokens`` cached K/V entries, a read of the input activations and a
    write of the output activations for the new tokens, and a write of the new
    tokens' K/V into the cache.

    Args:
        seq_len: total context length. Not used by this byte model; it is kept
            so the signature parallels ``attention_prefill_flops``.
        new_tokens: number of tokens being computed.
        cached_tokens: number of context tokens whose K/V are read from cache.

    Returns:
        Total bytes over ``L`` layers.
    """
    d_q = H * D_head  # query width (heads * head_dim), not the hidden size D
    d_kv = H_kv * D_head
    # Projection weights: Q (D x d_q), K and V (D x d_kv each), O (d_q x D)
    weight_bytes = (D * (d_q + 2 * d_kv) + d_q * D) * BYTES * L
    # Cached K and V for the reused context
    cached_kv_bytes = cached_tokens * 2 * d_kv * BYTES * L
    # Read input activations + write output activations
    act_bytes = new_tokens * D * BYTES * 2 * L
    # Write the new tokens' K and V to the cache
    new_kv_bytes = new_tokens * 2 * d_kv * BYTES * L
    return weight_bytes + cached_kv_bytes + act_bytes + new_kv_bytes


def ffn_flops(n_tokens):
    """Return MoE FFN FLOPs for ``n_tokens`` tokens, summed over all layers.

    Each token runs ``K_experts`` experts. Each expert applies three
    D x D_ffn matmuls (gate, up, down), counted at 2 FLOPs per multiply-add.
    """
    flops_one_layer = 3 * n_tokens * D * D_ffn * 2 * K_experts
    return flops_one_layer * L


def ffn_bytes(n_tokens):
    """Return HBM bytes moved by the MoE FFN for ``n_tokens`` tokens (all layers).

    Weights are charged once per *distinct* expert the batch touches. A single
    token uses ``K_experts`` experts, but a batch of tokens routes to the union
    of their choices. Under uniform, independent routing (an assumption; a
    skewed router would typically touch fewer experts), the expected number of
    distinct experts per layer is
    ``N_experts * (1 - (1 - K_experts / N_experts) ** n_tokens)``. That is
    exactly ``K_experts`` for one token (decode) and approaches ``N_experts``
    for long prefills.

    Args:
        n_tokens: number of tokens passing through the FFN in this step.

    Returns:
        Weight bytes for the touched experts plus input/output activation bytes,
        summed over ``L`` layers. The result is a float because the expert
        count is an expectation.
    """
    experts_touched = N_experts * (1 - (1 - K_experts / N_experts) ** n_tokens)
    # Each expert holds gate, up and down matrices of D x D_ffn elements
    weight_bytes = experts_touched * 3 * D * D_ffn * BYTES * L
    # Activations: read input + write output for every token
    act_bytes = n_tokens * D * BYTES * 2 * L
    return weight_bytes + act_bytes


def decode_flops(seq_len):
    """Return FLOPs to generate one token against a ``seq_len``-token context.

    Decode is modelled as a one-token prefill: attention over the full context
    plus a single token's pass through the MoE FFN.
    """
    attention = attention_prefill_flops(seq_len, 1)
    ffn = ffn_flops(1)
    return attention + ffn


def decode_bytes(seq_len):
    """Return HBM bytes moved to generate one token with a ``seq_len`` context.

    The whole context is treated as cached KV (``cached_tokens=seq_len``), and
    one token's attention and FFN traffic (weights and activations) is added.
    """
    attention = attention_prefill_bytes(seq_len, new_tokens=1, cached_tokens=seq_len)
    return attention + ffn_bytes(1)


# ===== Analysis =====


print("\n" + "-" * 80)
print(" PART 1: Decode Roofline (baseline)")
print("-" * 80)
print(f" {'SeqLen':>8} {'FLOP':>14} {'Bytes':>14} {'AI (F/B)':>10} {'Bound':>12}")

# One decode step (single token) per context length: its FLOPs, bytes and
# arithmetic intensity relative to the ridge point.
for ctx_len in (1000, 4000, 8000, 16000, 32000, 64000, 128000):
    step_flops, step_bytes = decode_flops(ctx_len), decode_bytes(ctx_len)
    intensity = step_flops / step_bytes
    regime = "COMPUTE" if intensity > RIDGE_POINT else "MEMORY"
    print(
        f" {ctx_len:>8,} {step_flops:>14.2e} {step_bytes:>14.2e}"
        f" {intensity:>10.1f} {regime:>12}"
    )

print("\n" + "-" * 80)
print(" PART 2: Prefill with KV Cache Reuse")
print(" (Total input = seq_len, cached = seq_len * reuse_ratio, new = rest)")
print("-" * 80)
print(f" {'SeqLen':>8} {'Reuse%':>7} {'NewTok':>8} {'FLOP':>14} {'Bytes':>14} {'AI (F/B)':>10} {'Bound':>12} {'vs Decode':>10}")

for seq_len in [4000, 16000, 32000, 64000, 128000]:
    # Decode AI at this context length is the "vs Decode" baseline. It depends
    # only on seq_len, so compute it once per group instead of once per reuse.
    dec_ai = decode_flops(seq_len) / decode_bytes(seq_len)

    for reuse in [0.0, 0.3, 0.5, 0.7, 0.9, 0.95]:
        cached = int(seq_len * reuse)
        new = seq_len - cached

        # Attention: compute on new tokens, but read cached KV for context
        attn_f = attention_prefill_flops(seq_len, new)
        attn_b = attention_prefill_bytes(seq_len, new, cached)

        # FFN: only on new tokens
        ffn_f = ffn_flops(new)
        ffn_b = ffn_bytes(new)

        total_f = attn_f + ffn_f
        total_b = attn_b + ffn_b
        ai = total_f / total_b if total_b > 0 else 0

        bound = "COMPUTE" if ai > RIDGE_POINT else "MEMORY"
        ratio = f"{ai/dec_ai:.1f}x" if dec_ai > 0 else "N/A"

        print(f" {seq_len:>8,} {reuse*100:>6.0f}% {new:>8,} {total_f:>14.2e} {total_b:>14.2e} {ai:>10.1f} {bound:>12} {ratio:>10}")
    print()  # blank line between context-length groups

print("-" * 80)
print(" PART 3: Key Thresholds")
print("-" * 80)

# For each context length, find the smallest reuse ratio (1% steps) at which a
# prefill's arithmetic intensity falls below the ridge point (memory-bound).
for seq_len in [4000, 16000, 32000, 64000, 128000]:
    for reuse_pct in range(0, 100):
        # Integer arithmetic avoids float rounding in int(seq_len * reuse).
        cached = seq_len * reuse_pct // 100
        new = seq_len - cached
        if new < 1:
            continue
        attn_f = attention_prefill_flops(seq_len, new)
        attn_b = attention_prefill_bytes(seq_len, new, cached)
        ffn_f = ffn_flops(new)
        ffn_b = ffn_bytes(new)
        ai = (attn_f + ffn_f) / (attn_b + ffn_b)
        if ai < RIDGE_POINT:
            print(f" SeqLen={seq_len:>6,}: prefill becomes memory-bound at {reuse_pct}% reuse (AI={ai:.1f})")
            break
    else:
        # No crossing found: say so rather than silently omitting the row.
        print(f" SeqLen={seq_len:>6,}: prefill stays compute-bound up to 99% reuse")

print()

print("-" * 80)
print(" PART 4: Agentic Workload Real Distribution")
print("-" * 80)

# Replay an actual sampled trace through a prefix cache and classify each
# request's prefill as compute- or memory-bound. The path defaults to the
# current sampled trace and can be overridden with --trace. parse_known_args()
# ignores unrelated flags. Because parsing happens here, --help output appears
# only after Parts 1-3 have printed.
_parser = argparse.ArgumentParser(description=__doc__)
_parser.add_argument("--trace", type=str,
                     default="traces/w600_r0.0015_st30.jsonl",
                     help="Sampled trace JSONL for empirical workload roofline (Part 4)")
_args, _ = _parser.parse_known_args()
trace_path = _args.trace

try:
    _trace_fh = open(trace_path, encoding="utf-8")
except FileNotFoundError:
    print(f" (skipped: trace file not found: {trace_path})")
    _trace_fh = None
except OSError as _err:
    # e.g. the path is a directory or unreadable: skip Part 4 instead of aborting
    print(f" (skipped: cannot open trace file {trace_path}: {_err})")
    _trace_fh = None

if _trace_fh is not None:
    BLOCK_SIZE = 512  # tokens per hash_ids entry -- assumed; confirm against the trace generator
    seen = set()  # every block hash seen so far: models an unbounded prefix cache
    compute_bound = 0
    memory_bound = 0
    total = 0
    reuse_sum = 0.0  # sum of per-request cached/seq_len, for the mean below

    with _trace_fh:  # closes the file even if a line fails to parse
        for line in _trace_fh:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline
            d = json.loads(line)
            seq_len = d["input_length"]
            if seq_len < 1:
                continue
            hids = d.get("hash_ids", [])

            # Prefix-cache hit: the leading run of blocks already seen
            cached_blocks = 0
            for hid in hids:
                if hid in seen:
                    cached_blocks += 1
                else:
                    break
            seen.update(hids)

            # cached_blocks * BLOCK_SIZE may exceed seq_len (e.g. if the final
            # block is partial). Cap it so at least one token is computed and the
            # cached-KV byte count never covers more tokens than the prompt has.
            cached = min(cached_blocks * BLOCK_SIZE, seq_len - 1)
            new = seq_len - cached
            reuse_sum += cached / seq_len

            attn_f = attention_prefill_flops(seq_len, new)
            attn_b = attention_prefill_bytes(seq_len, new, cached)
            ffn_f = ffn_flops(new)
            ffn_b = ffn_bytes(new)
            ai = (attn_f + ffn_f) / (attn_b + ffn_b)

            total += 1
            if ai > RIDGE_POINT:
                compute_bound += 1
            else:
                memory_bound += 1

    if total > 0:
        print(" With actual trace prefix cache pattern:")
        print(f" Compute-bound prefills: {compute_bound} ({compute_bound / total:.1%})")
        print(f" Memory-bound prefills: {memory_bound} ({memory_bound / total:.1%})")
        print(f" Mean prefix-cache reuse: {reuse_sum / total:.1%}")
        print(" (Decode is ALWAYS memory-bound at these seq lengths)")
        print()
        print(f" Implication: {memory_bound / total:.1%} of agentic prefills behave like decode")
        print(" → PD separation treats them as 'compute-heavy' but they are actually memory-heavy")
    else:
        print(f" (skipped: no usable requests in trace: {trace_path})")