13
13
14
14
from pyrosetta import *
15
15
from pyrosetta .rosetta import *
16
- init (extra_options = "-constant_seed -mute all" )
16
+ init (extra_options = "-constant_seed -mute all -read_only_ATOM_entries " )
17
17
18
18
def get_lddt (estogram , mask , center = 7 , weights = [1 ,1 ,1 ,1 ]):
19
19
# Remove diagonal from the mask.
@@ -57,6 +57,12 @@ def main():
57
57
default = False ,
58
58
help = "Make binder related predictions (Assumes chain A to be a binder)." )
59
59
60
+ parser .add_argument ("--savehidden" ,
61
+ "-sh" , action = "store" ,
62
+ type = str ,
63
+ default = "" ,
64
+ help = "saves last hidden layer if not empty (Default: " ")" )
65
+
60
66
parser .add_argument ("--reprocess" ,
61
67
"-r" ,
62
68
action = "store_true" ,
@@ -119,11 +125,17 @@ def main():
119
125
# Open with append
120
126
if not isfile (args .outfile ) or args .reprocess :
121
127
outfile = open (args .outfile , "w" )
122
- outfile .write ("name, global_lddt, interface_lddt, binder_lddt\n " )
128
+ if args .binder :
129
+ outfile .write ("name, global_lddt, interface_lddt, binder_lddt\n " )
130
+ else :
131
+ outfile .write ("name, global_lddt\n " )
123
132
done = []
124
133
else :
125
134
outfile = open (args .outfile , "a" )
126
135
done = pd .read_csv (args .outfile )["name" ].values
136
+
137
+ if args .savehidden != "" and not isdir (args .savehidden ):
138
+ os .mkdir (args .savehidden )
127
139
128
140
with torch .no_grad ():
129
141
# Parse through poses
@@ -132,9 +144,10 @@ def main():
132
144
133
145
input_stream .fill_pose (pose )
134
146
name = core .pose .tag_from_pose (pose )
135
- print (name )
136
147
if name in done :
148
+ print (name , "is already done." )
137
149
continue
150
+ print ("Working on" , name )
138
151
per_sample_result = [name ]
139
152
140
153
# This is where featurization happens
@@ -148,7 +161,12 @@ def main():
148
161
idx_g = torch .Tensor (idx .astype (np .int32 )).long ().to (device )
149
162
val_g = torch .Tensor (val ).to (device )
150
163
151
- estogram , mask , lddt , dmy = model (idx_g , val_g , f1d_g , f2d_g )
164
+ if args .savehidden != "" :
165
+ estogram , mask , lddt , hidden , dmy = model (idx_g , val_g , f1d_g , f2d_g , output_hidden_layer = True )
166
+ hidden = hidden .cpu ().detach ().numpy ()
167
+ np .save (join (args .savehidden , name + ".npy" ), hidden )
168
+ else :
169
+ estogram , mask , lddt , dmy = model (idx_g , val_g , f1d_g , f2d_g )
152
170
lddt = lddt .cpu ().detach ().numpy ()
153
171
estogram = estogram .cpu ().detach ().numpy ()
154
172
mask = mask .cpu ().detach ().numpy ()
@@ -177,7 +195,12 @@ def main():
177
195
val = val [index ]
178
196
idx_g = torch .Tensor (idx .astype (np .int32 )).long ().to (device )
179
197
val_g = torch .Tensor (val ).to (device )
180
- estogram , mask , lddt , dmy = model (idx_g , val_g , f1d_g [:blen ], f2d_g [:, :, :blen , :blen ])
198
+ if args .savehidden != "" :
199
+ estogram , mask , lddt , hidden , dmy = model (idx_g , val_g , f1d_g [:blen ], f2d_g [:, :, :blen , :blen ], output_hidden_layer = True )
200
+ hidden = hidden .cpu ().detach ().numpy ()
201
+ np .save (join (args .savehidden , name + "_b.npy" ), hidden )
202
+ else :
203
+ estogram , mask , lddt , dmy = model (idx_g , val_g , f1d_g [:blen ], f2d_g [:, :, :blen , :blen ])
181
204
lddt = lddt .cpu ().detach ().numpy ()
182
205
estogram = estogram .cpu ().detach ().numpy ()
183
206
mask = mask .cpu ().detach ().numpy ()
0 commit comments