Davis_train.py
import os

import torch
import torch.nn as nn

from Davis_gnn import GNNNet
from utils import *      # expected to provide train, predicting, collate, ...
from emetrics import *   # expected to provide get_mse, get_cindex, get_ci, get_rm2, get_pearson, get_spearman, get_rmse
from data_process import create_dataset_for_train
# datasets = [['davis', 'kiba'][int(sys.argv[1])]]
datasets = ['davis']
print('datasets:', datasets)
# cuda_name = ['cuda:0', 'cuda:1', 'cuda:2', 'cuda:3'][int(sys.argv[2])]
cuda_name = 'cuda:1'
print('cuda_name:', cuda_name)
# fold = [0, 1, 2, 3, 4][int(sys.argv[3])]
fold = 2
cross_validation_flag = True
# print(int(sys.argv[3]))
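# Example invocation if the sys.argv parsing above were re-enabled
# (hypothetical CLI, inferred from the commented-out lines):
#     python Davis_train.py 0 1 2   # dataset 'davis', device 'cuda:1', fold 2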
TRAIN_BATCH_SIZE = 512
TEST_BATCH_SIZE = 512
print('TRAIN_BATCH_SIZE:', TRAIN_BATCH_SIZE)
LR = 0.0007
NUM_EPOCHS = 2000
model_name = 'Davis_train'
print('model_name:', model_name)
print('Learning rate: ', LR)
print('Epochs: ', NUM_EPOCHS)
models_dir = 'models'
results_dir = 'Results'
if not os.path.exists(models_dir):
    os.makedirs(models_dir)
if not os.path.exists(results_dir):
    os.makedirs(results_dir)
# Main program: iterate over the selected datasets
result_str = ''
USE_CUDA = torch.cuda.is_available()
device = torch.device(cuda_name if USE_CUDA else 'cpu')
# device = torch.device('cpu')
model = GNNNet()
model_st = GNNNet.__name__
dataset = datasets[0]
model_file_name = 'models/' + model_name + '.model'
# model.load_state_dict(torch.load(model_file_name, map_location=cuda_name))
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
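# `train` and `predicting` are pulled in by the wildcard imports above. A
# minimal sketch of the assumed loop structure (hypothetical; the real
# helpers in utils may differ):
#
#     def train(model, device, loader, optimizer, epoch):
#         model.train()
#         loss_fn = nn.MSELoss()
#         for batch_mol, batch_pro in loader:
#             batch_mol, batch_pro = batch_mol.to(device), batch_pro.to(device)
#             optimizer.zero_grad()
#             loss = loss_fn(model(batch_mol, batch_pro), batch_mol.y.view(-1, 1).float())
#             loss.backward()
#             optimizer.step()
#
#     def predicting(model, device, loader):
#         model.eval()
#         labels, preds = [], []
#         with torch.no_grad():
#             for batch_mol, batch_pro in loader:
#                 batch_mol, batch_pro = batch_mol.to(device), batch_pro.to(device)
#                 preds.append(model(batch_mol, batch_pro).cpu())
#                 labels.append(batch_mol.y.view(-1, 1).cpu())
#         return torch.cat(labels).numpy().flatten(), torch.cat(preds).numpy().flatten()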
for dataset in datasets:
    # Load the pre-processed tensors for this cross-validation fold.
    train_path = 'PreData/train_' + str(fold) + '.pt'
    valid_path = 'PreData/valid_' + str(fold) + '.pt'
    train_data = torch.load(train_path)
    valid_data = torch.load(valid_path)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True,
                                               collate_fn=collate, num_workers=8, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=TEST_BATCH_SIZE, shuffle=False,
                                               collate_fn=collate, num_workers=8, pin_memory=True)
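    # `collate` also comes from the wildcard imports. A minimal sketch of the
    # assumed behavior (hypothetical; utils.collate may differ): batch the
    # (drug graph, protein graph) pairs with torch_geometric.
    #
    #     from torch_geometric.data import Batch
    #     def collate(data_list):
    #         batch_mol = Batch.from_data_list([data[0] for data in data_list])
    #         batch_pro = Batch.from_data_list([data[1] for data in data_list])
    #         return batch_mol, batch_pro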
    best_mse = 1000
    best_test_mse = 1000
    best_epoch = -1
    model_file_name = 'models/' + model_name + model_st + '_' + dataset + '_' + str(fold) + '.model'
    for epoch in range(NUM_EPOCHS):
        # Train for one epoch, then evaluate on the validation fold.
        train(model, device, train_loader, optimizer, epoch + 1)
        print('predicting for valid data')
        G, P = predicting(model, device, valid_loader)
        val = get_mse(G, P)
        print('valid result:', val, best_mse)
        if val < best_mse:
            best_mse = val
            best_epoch = epoch + 1
            torch.save(model.state_dict(), model_file_name)
            # Compute the full metric suite for the best checkpoint so far.
            cindex = get_cindex(G, P)    # DeepDTA
            cindex2 = get_ci(G, P)       # GraphDTA
            rm2 = get_rm2(G, P)          # DeepDTA
            pearson = get_pearson(G, P)
            spearman = get_spearman(G, P)
            rmse = get_rmse(G, P)
            result_file_name = 'Results/' + model_name + '.txt'
            result_str = "in epoch " + str(best_epoch) + \
                         " best_mse:" + str(best_mse) + \
                         " best_pearson:" + str(pearson) + \
                         " best_rm2:" + str(rm2) + \
                         " best_cindex:" + str(cindex) + \
                         " best_cindex2:" + str(cindex2) + \
                         " best_spearman:" + str(spearman) + \
                         " best_rmse:" + str(rmse)
            with open(result_file_name, 'w') as f:
                f.write(result_str)
            print('mse improved at epoch ', best_epoch, '; best_mse', best_mse, model_st, dataset, fold)
        else:
            print('No improvement since epoch ', best_epoch, '; best_mse', best_mse, model_st, dataset, fold)
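# The metric helpers are expected to come from emetrics. A minimal sketch of
# the two simplest ones, assuming 1-D numpy arrays of true affinities y and
# predictions f (hypothetical; the emetrics implementations may differ):
#
#     import numpy as np
#     def get_mse(y, f):
#         return np.mean((y - f) ** 2)
#     def get_rmse(y, f):
#         return np.sqrt(np.mean((y - f) ** 2))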