1
+ import sys
2
+ import argparse
3
+ import os
4
+ from os import listdir
5
+ from os .path import isfile , isdir , join
6
+ import numpy as np
7
+ import pandas as pd
8
+ import multiprocessing
9
+ import torch
10
+
11
def _build_parser():
    """Build the command-line parser for the error predictor."""
    parser = argparse.ArgumentParser(description="Error predictor network",
                                     epilog="v0.0.1")
    parser.add_argument("input",
                        action="store",
                        help="path to input folder or input pdb file")

    # NOTE(review): REMAINDER makes the output optional, but it also swallows
    # any option written after the positionals; options must come first.
    parser.add_argument("output",
                        action="store", nargs=argparse.REMAINDER,
                        help="path to output (folder path, npz, or csv)")

    parser.add_argument("--modelpath",
                        "-modelpath",
                        action="store",
                        default="NatComm_standard",
                        help="modelpath (default: NatComm_standard)")

    parser.add_argument("--pdb",
                        "-pdb",
                        action="store_true",
                        default=False,
                        help="Running on a single pdb file instead of a folder (Default: False)")

    parser.add_argument("--csv",
                        "-csv",
                        action="store_true",
                        default=False,
                        help="Writing results to a csv file (Default: False)")

    parser.add_argument("--leaveTempFile",
                        "-lt",
                        action="store_true",
                        default=False,
                        help="Leaving temporary files (Default: False)")

    parser.add_argument("--process",
                        "-p", action="store",
                        type=int,
                        default=1,
                        help="Specifying # of cpus to use for featurization (Default: 1)")

    parser.add_argument("--featurize",
                        "-f",
                        action="store_true",
                        default=False,
                        help="Running only the featurization part(Default: False)")

    parser.add_argument("--reprocess",
                        "-r", action="store_true",
                        default=False,
                        help="Reprocessing all feature files (Default: False)")

    parser.add_argument("--verbose",
                        "-v",
                        action="store_true",
                        default=False,
                        help="Activating verbose flag (Default: False)")

    parser.add_argument("--bert",
                        "-bert",
                        action="store_true",
                        default=False,
                        help="Run with bert features. Use extractBert.py to generate them. (Default: False)")

    parser.add_argument("--ensemble",
                        "-e",
                        action="store_true",
                        default=False,
                        help="Running with ensembling of 4 models. This adds 4x computational time with some overheads (Default: False)")

    # The prediction code reads args.cpu, so the option has to exist.
    parser.add_argument("--cpu",
                        action="store_true",
                        default=False,
                        help="Running on CPU even if a GPU is available (Default: False)")
    return parser


def _select_device(force_cpu):
    """Return the CUDA device when one is available and not overridden, else the CPU."""
    use_cuda = torch.cuda.is_available() and not force_cpu
    return torch.device("cuda:0" if use_cuda else "cpu")


def _load_model(model, checkpoint_path, device):
    """Load ``model_state_dict`` from a checkpoint into *model*, in eval mode on *device*."""
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()
    return model


def _predict(dan, model, device, feature_file, bertpath=""):
    """Run *model* on one feature file and return (estogram, mask, lddt) as numpy arrays."""
    (idx, val), (f1d, _bert), f2d, _ = dan.getData(feature_file, bertpath=bertpath)
    with torch.no_grad():
        f1d = torch.Tensor(f1d).to(device)
        # Move the last axis of f2d to the front and prepend an axis of size one.
        f2d = torch.Tensor(np.expand_dims(f2d.transpose(2, 0, 1), 0)).to(device)
        idx = torch.Tensor(idx.astype(np.int32)).long().to(device)
        val = torch.Tensor(val).to(device)
        estogram, mask, lddt, _ = model(idx, val, f1d, f2d)
    return tuple(t.cpu().detach().numpy() for t in (estogram, mask, lddt))


def _save_prediction(path, estogram, mask, lddt):
    """Write one prediction to *path* as a compressed npz of float16 arrays."""
    np.savez_compressed(path,
                        lddt=lddt.astype(np.float16),
                        estogram=estogram.astype(np.float16),
                        mask=mask.astype(np.float16))


def main():
    """Command-line entry point: featurize pdb input and predict its accuracy.

    Works either on a folder of ``.pdb`` files or, when the input ends with
    ``.pdb`` (or ``--pdb`` is given), on a single file. Results are written
    as ``.npz`` files, or as a tab-separated summary when ``--csv`` is used.

    Returns:
        0 on success, -1 when the arguments or required files are invalid.
    """
    #####################
    # Parsing arguments
    #####################
    args = _build_parser().parse_args()

    ################################
    # Checking file availabilities #
    ################################
    csvfilename = "result.csv"

    # The output is an optional positional argument collected as a list, so
    # check its length manually and unpack the string.
    if len(args.output) > 1:
        print(f"Only one output folder can be specified, but got {args.output}", file=sys.stderr)
        return -1
    args.output = args.output[0] if args.output else ""

    if args.input.endswith(".pdb"):
        args.pdb = True

    if args.output.endswith(".csv"):
        args.csv = True

    if not args.pdb:
        if not isdir(args.input):
            print("Input folder does not exist.", file=sys.stderr)
            return -1

        if args.output == "":
            # Default is the input folder.
            args.output = args.input
        elif args.csv:
            # With csv output, intermediate files are kept in the input folder.
            csvfilename = args.output
            args.output = args.input
        elif not isdir(args.output):
            if args.verbose:
                print("Creating output folder:", args.output)
            os.makedirs(args.output, exist_ok=True)
    else:
        if not isfile(args.input):
            print("Input file does not exist.", file=sys.stderr)
            return -1

        # Default is the input name with its extension changed to npz.
        if args.output == "":
            args.output = os.path.splitext(args.input)[0] + ".npz"

        # Suffix checks (not substring checks): the names are later derived
        # by stripping the last four characters.
        if not (args.input.endswith(".pdb") and args.output.endswith(".npz")):
            print("Input needs to be in .pdb format, and output needs to be in .npz format.", file=sys.stderr)
            return -1

    script_dir = os.path.dirname(os.path.abspath(__file__))
    modelpath = join(script_dir, "models/", args.modelpath)

    if not isdir(modelpath):
        print("Model checkpoint does not exist", file=sys.stderr)
        return -1

    ##############################
    # Importing larger libraries #
    ##############################
    # Imported only after validation so argument errors are reported first.
    sys.path.insert(0, script_dir)
    import deepAccNet as dan

    num_process = max(1, args.process)
    device = _select_device(args.cpu)
    twobody_size = 49 if args.bert else 33

    if not args.pdb:
        #########################
        # Getting samples names #
        #########################
        entries = os.listdir(args.input)
        samples = [name[:-4] for name in entries
                   if isfile(args.input + "/" + name)
                   and name.endswith(".pdb")
                   and not name.startswith(".")]
        num_ignored = len(entries) - len(samples)
        if args.verbose:
            print("# samples:", len(samples))
            if num_ignored > 0:
                print("# files ignored:", num_ignored)

        ##############################
        # Featurization happens here #
        ##############################
        jobs = [(join(args.input, s) + ".pdb",
                 join(args.output, s) + ".features.npz",
                 args.verbose) for s in samples]
        num_existing = sum(1 for job in jobs if isfile(job[1]))

        if not args.reprocess:
            jobs = [job for job in jobs if not isfile(job[1])]
            if args.verbose:
                print("Featurizing", len(jobs), "samples.", num_existing, "are already processed.")
        elif args.verbose:
            print("Featurizing", len(jobs), "samples.", num_existing, "are re-processed.")

        if num_process == 1:
            for job in jobs:
                dan.process(job)
        elif jobs:
            with multiprocessing.Pool(num_process) as pool:
                pool.map(dan.process, jobs)

        # Exit if only featurization is needed.
        if args.featurize:
            return 0

        if args.verbose:
            print("using", modelpath)

        ###########################
        # Prediction happens here #
        ###########################
        # Keep only samples whose feature (and, with --bert, bert) files exist.
        samples = [s for s in samples
                   if isfile(join(args.output, s + ".features.npz"))
                   and (not args.bert or isfile(join(args.output, "bert_" + s + ".npy")))]

        if args.ensemble:
            modelnames = ["best.pkl", "second.pkl", "third.pkl", "fourth.pkl"]
        else:
            modelnames = ["best.pkl"]

        result = {}
        for modelname in modelnames:
            model = _load_model(dan.DeepAccNet(twobody_size=twobody_size),
                                join(modelpath, modelname),
                                device)

            for s in samples:
                if args.ensemble:
                    outname = join(args.output, s + "_" + modelname[:-4] + ".npz")
                else:
                    outname = join(args.output, s + ".npz")
                bertname = join(args.output, "bert_" + s + ".npy") if args.bert else ""
                try:
                    if args.verbose:
                        print("Predicting for", s)
                    estogram, mask, lddt = _predict(dan, model, device,
                                                    join(args.output, s + ".features.npz"),
                                                    bertpath=bertname)
                    result.setdefault(s, []).append(np.mean(lddt))
                    if not args.csv:
                        _save_prediction(outname, estogram, mask, lddt)
                except Exception as err:
                    # Best effort: one bad sample must not abort the whole batch.
                    print("Failed to predict for", outname, "-", repr(err), file=sys.stderr)

        if not args.csv:
            if args.ensemble:
                dan.merge(samples, args.output, verbose=args.verbose)

            if not args.leaveTempFile:
                dan.clean(samples,
                          args.output,
                          verbose=args.verbose,
                          ensemble=args.ensemble)
        else:
            # Take the average of the outputs of all models for each sample.
            with open(csvfilename, "w") as csvfile:
                csvfile.write("sample\tcb-lddt\n")
                for s in samples:
                    if s not in result:
                        # Prediction failed for this sample (already reported).
                        continue
                    csvfile.write(f"{s}\t{np.mean(result[s]):.4f}\n")
        return 0

    ################################
    # Processing for single sample #
    ################################
    outfolder = os.path.dirname(args.output)
    outsamplename = os.path.basename(args.output)[:-4]
    feature_file_name = join(outfolder, outsamplename + ".features.npz")
    if args.verbose:
        print("only working on a file:", outfolder, outsamplename)

    # Process if the feature file does not exist or the reprocess flag is set.
    if (not isfile(feature_file_name)) or args.reprocess:
        dan.process((args.input, feature_file_name, args.verbose))

    if not isfile(feature_file_name):
        print(f"Feature file does not exist: {feature_file_name}", file=sys.stderr)
        return -1

    # Exit if only featurization is needed.
    if args.featurize:
        return 0

    if args.verbose:
        print("using", modelpath)

    # Same checkpoint layout as the folder mode; only the best model is used.
    model = _load_model(dan.DeepAccNet(twobody_size=twobody_size),
                        join(modelpath, "best.pkl"),
                        device)

    if args.verbose:
        print("Predicting for", outsamplename)
    # Bert file naming mirrors the folder mode ("bert_<sample>.npy").
    bertname = join(outfolder, "bert_" + outsamplename + ".npy") if args.bert else ""
    estogram, mask, lddt = _predict(dan, model, device, feature_file_name,
                                    bertpath=bertname)
    _save_prediction(args.output, estogram, mask, lddt)

    if not args.leaveTempFile:
        # NOTE(review): the folder mode calls dan.clean with ensemble=...,
        # while this call uses noEnsemble=...; confirm which keyword
        # dan.clean actually accepts.
        dan.clean([outsamplename],
                  outfolder,
                  verbose=args.verbose,
                  noEnsemble=True)
    return 0
326
+
327
+
328
if __name__ == "__main__":
    # main() returns a status code (-1 on validation failures); hand it to
    # sys.exit so the shell sees a non-zero exit status on errors.
    sys.exit(main())
0 commit comments