-
Notifications
You must be signed in to change notification settings - Fork 288
/
run_benchmark.py
48 lines (39 loc) · 1.32 KB
/
run_benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import argparse
import importlib
import os
import sys
import traceback
from pathlib import Path
from typing import Dict, Optional
# Absolute directory containing this script, with symlinks resolved.
CURRENT_DIR = os.path.split(os.path.realpath(__file__))[0]
def list_benchmarks(root: Optional[Path] = None) -> Dict[str, str]:
    """Discover the user benchmarks available to this runner.

    Each sub-directory of the ``userbenchmark`` package is a benchmark.
    Sub-directories of ``userbenchmark/fb`` are benchmarks too and map to
    the module path ``fb.<name>``. If a top-level benchmark has the same
    name as an ``fb`` one, the top-level one takes precedence.

    Args:
        root: Directory to scan instead of the installed ``userbenchmark``
            package. Defaults to the package's own directory.

    Returns:
        Mapping from benchmark name (as accepted on the command line) to its
        module path relative to ``userbenchmark``, sorted by name.
    """
    if root is None:
        # Local import: the package is only needed to locate the default root.
        import userbenchmark

        root = Path(userbenchmark.__file__).parent
    bdir = Path(root).resolve()

    def _is_benchmark_dir(path: Path) -> bool:
        # Skip interpreter/tooling artefacts such as __pycache__ or .git,
        # which would otherwise be offered as (unrunnable) benchmark names.
        return path.is_dir() and not path.name.startswith(("__", "."))

    benchmarks: Dict[str, str] = {}
    fb_bdir = bdir / "fb"
    # is_dir() rather than exists(): iterdir() on a plain file would raise.
    if fb_bdir.is_dir():
        for fb_bm in filter(_is_benchmark_dir, fb_bdir.iterdir()):
            benchmarks[fb_bm.name] = f"fb.{fb_bm.name}"
    # Top-level benchmarks are added last so they override same-named fb ones.
    for bm in filter(_is_benchmark_dir, bdir.iterdir()):
        if bm.name != "fb":
            benchmarks[bm.name] = bm.name
    # iterdir() order is filesystem-dependent; sort for stable CLI output.
    return dict(sorted(benchmarks.items()))
def run() -> int:
    """Run the user benchmark named on the command line.

    The first positional argument selects the benchmark. Every other
    argument, including ``-h``/``--help`` because ``add_help=False``, is
    forwarded untouched to ``userbenchmark.<bm>.run.run``.

    Returns:
        0 on success, 1 if the benchmark module could not be imported.
        Exceptions raised by the benchmark itself propagate to the caller.
    """
    available_benchmarks = list_benchmarks()
    parser = argparse.ArgumentParser(
        description="Run a TorchBench user benchmark", add_help=False
    )
    parser.add_argument(
        "bm_name",
        # Sorted so the usage/error message lists choices deterministically.
        choices=sorted(available_benchmarks),
        help="name of the user benchmark",
    )
    args, bm_args = parser.parse_known_args()
    module_name = f"userbenchmark.{available_benchmarks[args.bm_name]}.run"
    # Only the import is guarded: an ImportError raised while the benchmark
    # runs is a benchmark failure, not a failure to import the module.
    try:
        benchmark = importlib.import_module(module_name)
    except ImportError as e:
        print(
            f"Failed to import user benchmark module {args.bm_name}, error: {e}",
            file=sys.stderr,
        )
        traceback.print_exc()
        return 1
    benchmark.run(bm_args)
    return 0


if __name__ == "__main__":
    # Propagate failure as a non-zero exit status so CI does not report success.
    sys.exit(run())