33 lines
767 B
Python
33 lines
767 B
Python
from __future__ import annotations
|
|
|
|
import argparse
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
|
|
# Parent of the directory holding this script (presumably the repository
# root -- confirm if the script is ever moved to a deeper directory).
ROOT = Path(__file__).resolve().parents[1]

# Task name -> benchmark script. Every script lives in ``<ROOT>/bench`` and is
# named ``bench_<task>.py``; the tuple order fixes the dict's insertion order.
TASK_TO_SCRIPT = {
    task_name: ROOT / "bench" / f"bench_{task_name}.py"
    for task_name in ("vector_add", "softmax", "matmul", "attention")
}
|
|
|
|
|
|
def main(argv: list[str] | None = None) -> None:
    """Run the benchmark script registered for ``--task``.

    The selected script is executed with the current interpreter
    (``sys.executable``). Every argument this wrapper does not recognise is
    forwarded to the script unchanged and in its original order, so options
    aimed at the benchmark (e.g. ``--size 1024``) can be given directly. The
    older ``--task NAME -- ARGS...`` form still works: the first ``--``
    separator is dropped before forwarding.

    Args:
        argv: Arguments to parse; ``None`` (the default) means
            ``sys.argv[1:]``, as with ``argparse``.

    Raises:
        SystemExit: From argparse, when ``--task`` is missing or invalid.
        subprocess.CalledProcessError: If the benchmark exits non-zero.
    """
    parser = argparse.ArgumentParser(
        epilog="Arguments not recognised here are passed through unchanged "
        "to the selected benchmark script.",
    )
    parser.add_argument("--task", choices=sorted(TASK_TO_SCRIPT), required=True)

    # parse_known_args instead of a nargs="*" positional: a positional cannot
    # absorb option-like strings such as "--size", which argparse would reject
    # as unrecognized. Unknown strings are collected in command-line order.
    args, extra_args = parser.parse_known_args(argv)

    # Keep the previous "--task NAME -- ARGS" usage working: the first "--"
    # is the wrapper's separator, not an argument for the benchmark.
    if "--" in extra_args:
        extra_args.remove("--")

    cmd = [sys.executable, str(TASK_TO_SCRIPT[args.task]), *extra_args]
    subprocess.run(cmd, check=True)
|
|
|
|
|
|
if __name__ == "__main__":
    # main() returns None, so this exits with status 0 on success; any
    # exception raised by main() propagates before SystemExit is built.
    raise SystemExit(main())
|
|
|