11
11
import time
12
12
import pandas as pd
13
13
import os
14
+ import glob
14
15
15
16
from pyrosetta import *
16
17
from pyrosetta .rosetta import *
@@ -75,6 +76,14 @@ def main():
75
76
action = "store_true" ,
76
77
default = False ,
77
78
help = "Run with bert features. Use extractBert.py to generate them. (Default: False)" )
79
+
80
+ parser .add_argument ("--features_only" ,
81
+ action = "store_true" ,
82
+ help = "Just dump features" )
83
+
84
+ parser .add_argument ("--prediction_only" ,
85
+ action = "store_true" ,
86
+ help = "Assumes stored features" )
78
87
79
88
args = parser .parse_args ()
80
89
@@ -99,6 +108,21 @@ def main():
99
108
return - 1
100
109
101
110
if args .verbose : print ("using" , modelpath )
111
+
112
+ feature_folder = args .outfile + "_features/"
113
+
114
+ if ( args .features_only ):
115
+ if ( not os .path .exists (feature_folder ) ):
116
+ os .mkdir (feature_folder )
117
+
118
+ if ( args .prediction_only ):
119
+ if ( not os .path .exists (feature_folder )):
120
+ print ("--prediction_only: Features have not been generated. Run with --features_only first or remove this flag." )
121
+ return - 1
122
+
123
+ if ( args .features_only and args .prediction_only ):
124
+ print ("You can't specify both --features_only and --prediction_only at the same time." )
125
+ return - 1
102
126
103
127
##############################
104
128
# Importing larger libraries #
@@ -107,52 +131,75 @@ def main():
107
131
sys .path .insert (0 , script_dir )
108
132
import deepAccNet as dan
109
133
110
- model = dan .DeepAccNet (twobody_size = 49 if args .bert else 33 )
111
- device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
112
- checkpoint = torch .load (join (modelpath , "best.pkl" ), map_location = device )
113
- model .load_state_dict (checkpoint ["model_state_dict" ])
114
- model .to (device )
115
- model .eval ()
134
+ if ( not args .features_only ):
135
+ model = dan .DeepAccNet (twobody_size = 49 if args .bert else 33 )
136
+ device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
137
+ checkpoint = torch .load (join (modelpath , "best.pkl" ), map_location = device )
138
+ model .load_state_dict (checkpoint ["model_state_dict" ])
139
+ model .to (device )
140
+ model .eval ()
116
141
117
142
#############################
118
143
# Parse through silent file #
119
144
#############################
120
145
121
- silent_files = utility . vector1_utility_file_FileName ()
122
- for silent_file in basic . options . get_file_vector_option ( "in:file: silent" ):
123
- silent_files . append ( utility . file . FileName ( args .infile ) )
124
- input_stream = core . import_pose . pose_stream . SilentFilePoseInputStream ( args . infile )
146
+ # loading the silent like this allows us to get names without loading poses
147
+ sfd_in = rosetta . core . io . silent . SilentFileData ( rosetta . core . io . silent . SilentFileOptions ())
148
+ sfd_in . read_file ( args .infile )
149
+ names = sfd_in . tags ( )
125
150
126
151
# Open with append
127
152
if not isfile (args .outfile ) or args .reprocess :
128
153
outfile = open (args .outfile , "w" )
129
154
if args .binder :
130
- outfile .write ("name, global_lddt, interface_lddt, binder_lddt\n " )
155
+ outfile .write ("global_lddt interface_lddt binder_lddt description \n " )
131
156
else :
132
- outfile .write ("name, global_lddt\n " )
157
+ outfile .write ("global_lddt description \n " )
133
158
done = []
134
159
else :
135
160
outfile = open (args .outfile , "a" )
136
- done = pd .read_csv (args .outfile )["name" ].values
161
+ done = pd .read_csv (args .outfile , sep = "\s+" )["description" ].values
162
+
137
163
138
164
if args .savehidden != "" and not isdir (args .savehidden ):
139
165
os .mkdir (args .savehidden )
140
166
141
167
with torch .no_grad ():
142
168
# Parse through poses
143
169
pose = core .pose .Pose ()
144
- while input_stream .has_another_pose ():
145
-
146
- input_stream .fill_pose (pose )
147
- name = core .pose .tag_from_pose (pose )
170
+ for name in names :
171
+
148
172
if name in done :
149
173
print (name , "is already done." )
150
174
continue
175
+
176
+
151
177
print ("Working on" , name )
152
178
per_sample_result = [name ]
179
+ feature_file = feature_folder + name
180
+
181
+
153
182
154
183
# This is where featurization happens
155
- features = dan .process_from_pose (pose )
184
+ if ( args .prediction_only ):
185
+ try :
186
+ features = np .load (feature_file + ".npz" )
187
+ except :
188
+ print ("Unable to load features for " + name )
189
+ continue
190
+ else :
191
+ if ( args .features_only and os .path .exists ( feature_file + ".npz" )):
192
+ print (name , "is already done." )
193
+ continue
194
+
195
+ sfd_in .get_structure (name ).fill_pose (pose )
196
+ features = dan .process_from_pose (pose )
197
+ features ['blen' ] = np .array (pose .conformation ().chain_end (1 ) - pose .conformation ().chain_begin (1 ) + 1 )
198
+
199
+ if ( args .features_only ):
200
+ np .savez (feature_file , ** features )
201
+ continue
202
+
156
203
157
204
# This is where prediction happens
158
205
# For the whole
@@ -179,7 +226,7 @@ def main():
179
226
if args .binder :
180
227
181
228
# Binder length
182
- blen = pose . conformation (). chain_end ( 1 ) - pose . conformation (). chain_begin ( 1 ) + 1
229
+ blen = features [ 'blen' ]
183
230
plen = estogram .shape [- 1 ]
184
231
if blen == plen :
185
232
continue
@@ -210,12 +257,15 @@ def main():
210
257
# Write the result
211
258
if args .binder :
212
259
r = per_sample_result
213
- outfile .write ("%s, %5f, %5f, %5f \n " % (r [0 ], r [1 ], r [2 ], r [3 ]))
260
+ outfile .write ("%5f %5f %5f %s \n " % (r [1 ], r [2 ], r [3 ], r [0 ]))
214
261
else :
215
262
r = per_sample_result
216
- outfile .write ("%s, %5f \n " % (r [0 ], r [1 ]))
263
+ outfile .write ("%5f %s \n " % (r [1 ], r [0 ]))
217
264
outfile .flush ()
218
265
os .fsync (outfile .fileno ())
266
+
267
+ if ( args .prediction_only ):
268
+ os .remove (feature_file + ".npz" )
219
269
220
270
outfile .close ()
221
271
0 commit comments