compute_roofline: argparse --trace, fix stale default path (D4)

The hardcoded traces/sampled_1000req_seed42.jsonl no longer exists; switch
the default to the current sampled trace file,
traces/w600_r0.0015_st30.jsonl, and let users override it via --trace.
When the file is missing, catch FileNotFoundError from open() and print a
"skipped" notice for Part 4, replacing the os.path.exists check.
This commit is contained in:
2026-05-23 20:58:09 +08:00
parent 547611e022
commit ea5c3bfe6b

View File

@@ -12,6 +12,7 @@ GPU: NVIDIA H20
- Roofline ridge point: 148/4.0 = 37 FLOP/byte - Roofline ridge point: 148/4.0 = 37 FLOP/byte
""" """
import argparse
import json import json
import math import math
@@ -161,16 +162,25 @@ print(" PART 4: Agentic Workload Real Distribution")
print("-" * 80) print("-" * 80)
# Use actual trace data # Use actual trace data
import os _parser = argparse.ArgumentParser(description=__doc__)
trace_path = "traces/sampled_1000req_seed42.jsonl" _parser.add_argument("--trace", type=str,
if os.path.exists(trace_path): default="traces/w600_r0.0015_st30.jsonl",
help="Sampled trace JSONL for empirical workload roofline (Part 4)")
_args, _ = _parser.parse_known_args()
trace_path = _args.trace
try:
_trace_fh = open(trace_path)
except FileNotFoundError:
print(f" (skipped: trace file not found: {trace_path})")
_trace_fh = None
if _trace_fh is not None:
BLOCK_SIZE = 512 BLOCK_SIZE = 512
seen = set() seen = set()
compute_bound = 0 compute_bound = 0
memory_bound = 0 memory_bound = 0
total = 0 total = 0
for line in open(trace_path): for line in _trace_fh:
d = json.loads(line) d = json.loads(line)
seq_len = d["input_length"] seq_len = d["input_length"]
if seq_len < 1: continue if seq_len < 1: continue
@@ -201,6 +211,8 @@ if os.path.exists(trace_path):
else: else:
memory_bound += 1 memory_bound += 1
_trace_fh.close()
if total > 0:
print(f" With actual trace prefix cache pattern:") print(f" With actual trace prefix cache pattern:")
print(f" Compute-bound prefills: {compute_bound} ({compute_bound*100//total}%)") print(f" Compute-bound prefills: {compute_bound} ({compute_bound*100//total}%)")
print(f" Memory-bound prefills: {memory_bound} ({memory_bound*100//total}%)") print(f" Memory-bound prefills: {memory_bound} ({memory_bound*100//total}%)")