@@ -34,6 +34,12 @@ def main():
34
34
action = "store_true" ,
35
35
default = False ,
36
36
help = "Writing results to a csv file (Default: False)" )
37
+
38
+ parser .add_argument ("--per_res_only" ,
39
+ "-pr" ,
40
+ action = "store_true" ,
41
+ default = False ,
42
+ help = "Store per-residue accuracy only (Default: False)" )
37
43
38
44
parser .add_argument ("--leaveTempFile" ,
39
45
"-lt" ,
@@ -242,15 +248,23 @@ def main():
242
248
243
249
if not args .csv :
244
250
if args .ensemble :
245
- np .savez_compressed (join (args .output , s + "_" + modelname [:- 4 ]+ ".npz" ),
246
- lddt = lddt .cpu ().detach ().numpy ().astype (np .float16 ),
247
- estogram = estogram .cpu ().detach ().numpy ().astype (np .float16 ),
248
- mask = mask .cpu ().detach ().numpy ().astype (np .float16 ))
251
+ if args .per_res_only :
252
+ np .savez_compressed (join (args .output , s + "_" + modelname [:- 4 ]+ ".npz" ),
253
+ lddt = lddt .cpu ().detach ().numpy ().astype (np .float16 ))
254
+ else :
255
+ np .savez_compressed (join (args .output , s + "_" + modelname [:- 4 ]+ ".npz" ),
256
+ lddt = lddt .cpu ().detach ().numpy ().astype (np .float16 ),
257
+ estogram = estogram .cpu ().detach ().numpy ().astype (np .float16 ),
258
+ mask = mask .cpu ().detach ().numpy ().astype (np .float16 ))
249
259
else :
250
- np .savez_compressed (join (args .output , s + ".npz" ),
251
- lddt = lddt .cpu ().detach ().numpy ().astype (np .float16 ),
252
- estogram = estogram .cpu ().detach ().numpy ().astype (np .float16 ),
253
- mask = mask .cpu ().detach ().numpy ().astype (np .float16 ))
260
+ if args .per_res_only :
261
+ np .savez_compressed (join (args .output , s + ".npz" ),
262
+ lddt = lddt .cpu ().detach ().numpy ().astype (np .float16 ))
263
+ else :
264
+ np .savez_compressed (join (args .output , s + ".npz" ),
265
+ lddt = lddt .cpu ().detach ().numpy ().astype (np .float16 ),
266
+ estogram = estogram .cpu ().detach ().numpy ().astype (np .float16 ),
267
+ mask = mask .cpu ().detach ().numpy ().astype (np .float16 ))
254
268
except :
255
269
print ("Failed to predict for" , join (args .output , s + "_" + modelname [:- 4 ]+ ".npz" ))
256
270
@@ -311,10 +325,14 @@ def main():
311
325
val = torch .Tensor (val ).to (device )
312
326
313
327
estogram , mask , lddt , dmy = model (idx , val , f1d , f2d )
314
- np .savez_compressed (outsamplename + ".npz" ,
315
- lddt = lddt .cpu ().detach ().numpy ().astype (np .float16 ),
316
- estogram = estogram .cpu ().detach ().numpy ().astype (np .float16 ),
317
- mask = mask .cpu ().detach ().numpy ().astype (np .float16 ))
328
+ if args .per_res_only :
329
+ np .savez_compressed (outsamplename + ".npz" ,
330
+ lddt = lddt .cpu ().detach ().numpy ().astype (np .float16 ))
331
+ else :
332
+ np .savez_compressed (outsamplename + ".npz" ,
333
+ lddt = lddt .cpu ().detach ().numpy ().astype (np .float16 ),
334
+ estogram = estogram .cpu ().detach ().numpy ().astype (np .float16 ),
335
+ mask = mask .cpu ().detach ().numpy ().astype (np .float16 ))
318
336
319
337
if not args .leaveTempFile :
320
338
dan .clean ([outsamplename ],
@@ -326,4 +344,4 @@ def main():
326
344
327
345
328
346
if __name__ == "__main__" :
329
- main ()
347
+ main ()
0 commit comments