-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathutil.py
43 lines (31 loc) · 1.4 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import numpy as np
from csv import reader, writer
from sys import argv
# Number of leading columns in each row that are compared when two rows are
# checked for similarity; columns after these are ignored by that comparison.
number_of_inputs = 5
def read_data(file_name):
    """Load a semicolon-delimited numeric CSV file as a shuffled float32 array.

    Parameters
    ----------
    file_name : str or path-like
        Path of the CSV file to read. Fields are separated by ``;``.

    Returns
    -------
    numpy.ndarray
        ``float32`` array with one row per non-blank CSV line. The rows are
        shuffled in place with ``np.random.shuffle`` before being returned,
        so their order differs from the file order. A file with no data
        rows yields an empty 1-D array.

    Raises
    ------
    OSError
        If the file cannot be opened.
    ValueError
        If a field cannot be converted to a float.
    """
    # The csv module asks for files to be opened with newline='' so that it
    # can do its own newline handling.
    with open(file_name, 'r', newline='') as file:
        # csv.reader yields an empty list for a blank line (e.g. a trailing
        # newline at the end of the file). Such rows would make the array
        # ragged and break the float conversion, so they are skipped.
        rows = [row for row in reader(file, delimiter=';') if row]
    data = np.array(rows).astype(np.float32)
    np.random.shuffle(data)
    return data
# Filter out near-duplicate observations. The result keeps the same number of hits and misses, so that training sees both outcomes equally often and does not overfit to one of them.
def filter_data(file_name, hit_threshold=0.001, miss_threshold=0.2):
    """Remove near-duplicate observations and write a class-balanced CSV.

    Reads ``<file_name>.csv``, splits the rows by the value in their last
    column (``1`` = hit, ``0`` = miss), removes similar rows from each group
    with ``_filter``, truncates both groups to the same length and writes
    the hits followed by the misses to ``<file_name>_filtered.csv``.

    Parameters
    ----------
    file_name : str
        Path of the input file without the ``.csv`` extension.
    hit_threshold : float, optional
        Similarity threshold passed to ``_filter`` for the hit rows.
    miss_threshold : float, optional
        Similarity threshold passed to ``_filter`` for the miss rows.

    Notes
    -----
    The output file is overwritten when the process was started without
    command-line arguments (``len(sys.argv) == 1``) and appended to
    otherwise.
    """
    data = read_data(f'{file_name}.csv')
    labels = data[:, -1]
    hits = _filter(data[labels == 1], hit_threshold)
    misses = _filter(data[labels == 0], miss_threshold)

    # Keep the same number of hits and misses.
    count = min(hits.shape[0], misses.shape[0])
    hits = hits[:count]
    misses = misses[:count]

    mode = 'w' if len(argv) == 1 else 'a'
    # newline='' lets csv.writer control line endings itself; without it the
    # output gets an extra blank line after every row on Windows.
    with open(f'{file_name}_filtered.csv', mode, newline='') as file:
        writer(file, delimiter=';').writerows(np.append(hits, misses, axis=0))
def _filter(data, threshold):
    """Drop every row that is similar to an earlier row of ``data``.

    The rows are visited in order. For the row at position ``i``, all later
    rows that ``_similar_indexes`` reports as similar under ``threshold``
    are deleted; the scan then moves on to the next remaining row.

    Parameters
    ----------
    data : numpy.ndarray
        2-D array with one observation per row.
    threshold : float
        Similarity threshold forwarded to ``_similar_indexes``.

    Returns
    -------
    numpy.ndarray
        A copy of ``data`` without the rows found similar to an earlier
        kept row. The input array itself is not modified.
    """
    i = 0
    while i < data.shape[0] - 1:
        print(i, data.shape[0])  # progress: current row, rows still present
        # _similar_indexes only sees the tail data[i + 1:], so the positions
        # it returns are relative to that slice. Shift them by i + 1 to get
        # row numbers in `data`; deleting the unshifted positions would
        # remove the wrong rows (including already-accepted ones).
        duplicates = _similar_indexes(data[i], data[i + 1:], threshold) + i + 1
        data = np.delete(data, duplicates, axis=0)
        i += 1
    return data
def _similar_indexes(x, y, threshold):
    """Return the positions of the rows of ``y`` that are similar to ``x``.

    Only the first ``number_of_inputs`` columns are compared. A row counts
    as similar when the sum of the absolute differences over those columns
    is strictly below ``threshold``.

    Parameters
    ----------
    x : numpy.ndarray
        1-D reference row.
    y : numpy.ndarray
        2-D array of candidate rows.
    threshold : float
        Exclusive upper bound on the summed absolute difference.

    Returns
    -------
    numpy.ndarray
        Integer positions into ``y`` (not into any array ``y`` was sliced
        from) of the similar rows.
    """
    reference = x[:number_of_inputs]
    candidates = y[:, :number_of_inputs]
    # A sum of absolute values is already non-negative, so no further abs()
    # is needed before comparing against the threshold.
    distances = np.abs(candidates - reference).sum(axis=1)
    return np.flatnonzero(distances < threshold)
# Script entry point: filter 'observations.csv' into
# 'observations_filtered.csv' when this file is run directly.
if __name__ == "__main__":
    filter_data("observations")