-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathgenerate_sparse_code.py
123 lines (108 loc) · 5.36 KB
/
generate_sparse_code.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Indentation unit written into every generated Python line.
# NOTE(review): a single space is valid Python but PEP 8 code normally uses
# four spaces; confirm this width is intended.
INDENTATION = " "
def get_fuc_name(fuc_var):
    """Return the function name from a C/C++ signature already split on '('.

    ``fuc_var[0]`` is the text before the first '(' -- the return type followed
    by the function name, e.g. ``"int gspmm"`` -> ``"gspmm"``.

    Splitting on arbitrary whitespace (instead of a single space) keeps the
    result correct when the name is followed by a space (``"int gspmm ("``)
    or separated from the return type by a tab. Returns '' when there are no
    tokens at all.
    """
    tokens = fuc_var[0].split()
    return tokens[-1] if tokens else ''
def remove_empty_string(string_list):
    """Return a new list holding every entry of *string_list* except ''.

    Order is preserved; whitespace-only strings are kept.
    """
    return list(filter(lambda token: token != '', string_list))
def remove_unnecessary_chars(string):
    """Return *string* with every newline and closing parenthesis deleted.

    Used to clean the final token of an argument, which carries the
    signature's ')' and the line's trailing newline.
    """
    # Third maketrans argument = characters to delete; one pass over the text.
    return string.translate(str.maketrans('', '', '\n)'))
def cal_array_class(array_dim_list, i):
    """Return the rank ('1', '2' or '3') of argument *i*'s array type as a string.

    ``array_dim_list[i][0]`` is the first type token of argument *i*, e.g.
    ``"array2d_t<float>&"``. The first of the digits 1, 2, 3 (checked in that
    order) found in the type name is returned. ``"10000"`` is returned when
    none is found -- a sentinel outside the 1..3 range.

    Only the text before any '<' is scanned. Template arguments such as
    ``int16_t`` or ``int32_t`` also contain those digits and would otherwise
    give the wrong rank (e.g. ``array3d_t<int32_t>`` was reported as 2).
    """
    type_name = array_dim_list[i][0].split('<', 1)[0]
    for dim in ('1', '2', '3'):
        if dim in type_name:
            return dim
    return "10000"
def get_arguments(var_list, array_dim_list, output_index_list, array_index_list):
    """Classify each function argument with a numeric code for code generation.

    Note the parameter order: output indices come before array indices.

    Returns a list with one entry per recognised argument, in argument order:

    * ``[1, name, rank]`` -- input array (index in *array_index_list*)
    * ``[4, name, rank]`` -- output array (index in *output_index_list*)
    * ``[0, 'graph']``, ``[2, 'op']``, ``[3, 'reverse']``, ``[5, 'norm']`` --
      any other argument whose name contains that keyword; keywords are tried
      in that order and the first match wins

    *rank* is ``int(cal_array_class(array_dim_list, i))``. Arguments that are
    neither arrays nor contain a keyword are dropped.
    """
    keyword_codes = {"graph": 0, "op": 2, "reverse": 3, "norm": 5}
    array_indices = set(array_index_list)
    output_indices = set(output_index_list)
    output_list = []
    for i, var_list_item in enumerate(var_list):
        if i in array_indices or i in output_indices:
            # array_index_list is checked first, so an index in both is an input.
            code = 1 if i in array_indices else 4
            rank = int(cal_array_class(array_dim_list, i))
            output_list.append([code, var_list_item, rank])
            continue
        for key, code in keyword_codes.items():
            if key in var_list_item:
                output_list.append([code, key])
                break
    return output_list
def make_arguments(output_list, string_dict):
    """Render classified arguments as a comma-separated parameter string.

    Args:
        output_list: entries of the form ``[code, name]`` or
            ``[code, name, rank]`` (as produced by ``get_arguments``).
        string_dict: maps a code to the literal parameter name to emit,
            e.g. ``{0: 'graph', 2: 'op', 3: 'reverse', 5: 'norm'}``. It is
            consulted before the array codes.

    Returns:
        The parameters joined by ``", "`` and always ending with ``device0``.
        Input arrays (code 1) are emitted by name. Output arrays (code 4) of
        rank 1-3 are replaced by one ``dim<id>_<k>`` parameter per dimension,
        where ``<id>`` is the name with ``output`` removed. Output arrays of
        any other rank, and unknown codes, are omitted.
    """
    parts = []
    for item in output_list:
        code = item[0]
        if code in string_dict:
            parts.append(string_dict[code])
        elif code == 1:
            parts.append(item[1])
        elif code == 4 and item[2] in range(1, 4):
            suffix = item[1].replace("output", "")
            parts.extend(f'dim{suffix}_{k}' for k in range(item[2]))
    parts.append("device0")
    return ', '.join(parts)
def fuc_var_class(fuc_name):
    """Split a signature's argument list into names, type tokens and array indices.

    Args:
        fuc_name: the signature split on '('. ``fuc_name[1]`` holds the
            comma-separated argument list, possibly ending in ')' and a
            newline.

    Returns:
        ``(var_list, array_dim_list, array_index_list, output_index_list)``:

        * var_list: the last token of each argument, with newlines and ')'
          removed ('' for an empty argument);
        * array_dim_list: the tokens before it for each argument;
        * array_index_list: indices of arguments whose first type token
          contains ``array`` and whose name does not contain ``output``;
        * output_index_list: indices of such array arguments whose name
          does contain ``output``.
    """
    var_list = []
    array_dim_list = []
    for argument in fuc_name[1].split(","):
        # Clean first, then split on any whitespace. A trailing space
        # (e.g. "int64_t reverse) " before "{") or a detached ")" can then
        # no longer become an empty final token that drops the argument.
        tokens = remove_unnecessary_chars(argument).split()
        var_list.append(tokens[-1] if tokens else '')
        array_dim_list.append(tokens[:-1])
    # Guard on `item` so an argument with no type tokens cannot raise IndexError.
    array_index_list = [i for (i, item) in enumerate(array_dim_list)
                        if item and 'array' in item[0]]
    output_index_list = [i for i in array_index_list if 'output' in var_list[i]]
    array_index_list = [i for i in array_index_list if i not in output_index_list]
    return var_list, array_dim_list, array_index_list, output_index_list
def generate_base_function(function_name, output_list, string_dict):
    """Emit Python source for the public wrapper ``<function_name>``.

    The emitted function takes the parameters rendered by ``make_arguments``.
    It defines an inner ``@tf.custom_gradient`` function ``_lambda`` whose
    parameters are the ``input*`` parameters renamed to ``X*``. ``_lambda``
    forwards every parameter (with the same renaming) to
    ``<function_name>_real``. The wrapper returns ``_lambda`` applied to
    the original inputs.
    """
    params = make_arguments(output_list, string_dict)
    input_names = [name for name in params.split(', ') if 'input' in name]
    renamed_inputs = ", ".join(name.replace("input", "X") for name in input_names)
    lines = [
        f'def {function_name}({params}):\n',
        f'{INDENTATION}@tf.custom_gradient\n',
        f'{INDENTATION}def _lambda({renamed_inputs}):\n',
        f'{INDENTATION*2}return {function_name}_real({params.replace("input", "X")})\n',
        f'{INDENTATION}return _lambda({", ".join(input_names)})\n\n',
    ]
    return ''.join(lines)
def generate_real_function(function_name, output_list, string_dict):
    """Emit Python source for ``<function_name>_real``, the forward/grad pair.

    The emitted function calls ``gp_apis.gp_<function_name>`` with every
    ``reverse`` in the parameter list replaced by the literal ``1``. It
    returns that result together with an inner ``grad`` function. ``grad``
    repeats the call with ``reverse`` replaced by ``0`` and every ``input*``
    parameter renamed to ``dZ*``.

    NOTE(review): the forward call uses reverse=1 and the gradient reverse=0;
    confirm this orientation is intended.
    NOTE(review): ``tf.custom_gradient`` calls ``grad`` with one upstream
    gradient per *output*, while the emitted ``grad`` takes one ``dZ*`` per
    *input*. These only agree when there is a single input.
    """
    params = make_arguments(output_list, string_dict)
    input_names = [name for name in params.split(', ') if 'input' in name]
    grad_params = ", ".join(name.replace("input", "dZ") for name in input_names)
    forward_call = params.replace("reverse", "1")
    backward_call = params.replace("reverse", "0").replace("input", "dZ")
    pieces = (
        f'def {function_name}_real({params}):\n',
        f'{INDENTATION}out = gp_apis.gp_{function_name}({forward_call})\n',
        f'{INDENTATION}def grad({grad_params}):\n',
        f'{INDENTATION*2}return gp_apis.gp_{function_name}({backward_call})\n',
        f'{INDENTATION}return out, grad\n\n',
    )
    return ''.join(pieces)
def generate_code(line_string):
    """Generate the wrapper and ``_real`` functions for one C/C++ signature line.

    *line_string* is a signature such as
    ``int gspmm(graph_t& graph, array2d_t<float>& input1, ...)``. Anything
    from the first '{' onwards is ignored. Blank or whitespace-only lines
    produce no code; previously they raised IndexError.
    """
    if not line_string.strip():
        return ''
    # Codes (as assigned by get_arguments) -> fixed parameter names in the output.
    string_dict = {0: 'graph', 2: 'op', 3: 'reverse', 5: 'norm'}
    fuc_var = line_string.split("{")[0].split("(")
    function_name = get_fuc_name(fuc_var)
    var_list, array_dim_list, array_index_list, output_index_list = fuc_var_class(fuc_var)
    # get_arguments takes the output indices before the array indices.
    output_list = get_arguments(var_list, array_dim_list, output_index_list, array_index_list)
    return (generate_base_function(function_name, output_list, string_dict)
            + generate_real_function(function_name, output_list, string_dict))
def generate_sparse_file(input_file, output_file):
    """Write a TensorFlow wrapper module generated from a file of signatures.

    Each non-blank line of *input_file* is treated as one C/C++ function
    signature and converted to Python source by ``generate_code``. The
    result, preceded by imports of ``tensorflow`` and ``gp_apis``, is
    written to *output_file*, overwriting it. All code is generated before
    the output file is opened, so a failure leaves any existing output
    untouched. Blank lines are skipped instead of crashing the generator.
    Both files use UTF-8 explicitly rather than the platform default.
    """
    header = 'import tensorflow as tf\nimport gp_apis\n\n'
    with open(input_file, 'r', encoding='utf-8') as source:
        body = ''.join(generate_code(line) for line in source if line.strip())
    with open(output_file, 'w', encoding='utf-8') as target:
        target.write(header + body)