forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 3
/
te_utils.py
41 lines (32 loc) · 1.27 KB
/
te_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch
class ExecutionCounter(object):
    """Track how far a named JIT execution trigger advances.

    The trigger is sampled once when the counter is constructed;
    ``elapsed_value`` samples it again and reports the difference.
    Triggers are read through ``torch._C._jit_get_trigger_value``.
    """

    def __init__(self, name):
        # Trigger name passed verbatim to torch._C._jit_get_trigger_value.
        self.name = name
        # Baseline sample; later readings are reported relative to this.
        self.start_value = self.try_get_trigger_value()

    def try_get_trigger_value(self):
        """Return the trigger's current value, falling back to 0.

        Any exception raised while reading the trigger is swallowed and
        treated as a value of 0.
        NOTE(review): presumably this lets the counters degrade gracefully
        on builds where the binding or the named trigger is unavailable.
        Confirm before narrowing the except clause.
        """
        try:
            current = torch._C._jit_get_trigger_value(self.name)
        except Exception:
            current = 0
        return current

    def elapsed_value(self):
        """Return the change in the trigger value since construction."""
        return self.try_get_trigger_value() - self.start_value
class CudaCodeGenCreated(ExecutionCounter):
    """ExecutionCounter bound to the ``cuda_codegen_created`` trigger.

    NOTE(review): presumably counts CUDA code generators created; confirm
    against the trigger's definition on the C++ side.
    """

    def __init__(self):
        trigger_name = "cuda_codegen_created"
        super(CudaCodeGenCreated, self).__init__(trigger_name)
class CudaCodeGenExecuted(ExecutionCounter):
    """ExecutionCounter bound to the ``cuda_codegen_executed`` trigger.

    NOTE(review): presumably counts executions of CUDA-generated code;
    confirm against the trigger's definition on the C++ side.
    """

    def __init__(self):
        trigger_name = "cuda_codegen_executed"
        super(CudaCodeGenExecuted, self).__init__(trigger_name)
class LLVMCodeGenCreated(ExecutionCounter):
    """ExecutionCounter bound to the ``llvm_codegen_created`` trigger.

    NOTE(review): presumably counts LLVM code generators created; confirm
    against the trigger's definition on the C++ side.
    """

    def __init__(self):
        trigger_name = "llvm_codegen_created"
        super(LLVMCodeGenCreated, self).__init__(trigger_name)
class LLVMCodeGenExecuted(ExecutionCounter):
    """ExecutionCounter bound to the ``llvm_codegen_executed`` trigger.

    NOTE(review): presumably counts executions of LLVM-generated code;
    confirm against the trigger's definition on the C++ side.
    """

    def __init__(self):
        trigger_name = "llvm_codegen_executed"
        super(LLVMCodeGenExecuted, self).__init__(trigger_name)
class SimpleIREvalExecuted(ExecutionCounter):
    """ExecutionCounter bound to the ``simple_ir_eval_executed`` trigger.

    NOTE(review): presumably counts runs of the simple IR evaluator;
    confirm against the trigger's definition on the C++ side.
    """

    def __init__(self):
        trigger_name = "simple_ir_eval_executed"
        super(SimpleIREvalExecuted, self).__init__(trigger_name)