Skip to content

Commit a541507

Browse files
author
Naozumi Hiranuma
committed
pyrosetta only reads atom lines now
1 parent a14192a commit a541507

File tree

3 files changed

+29
-7
lines changed

3 files changed

+29
-7
lines changed

DeepAccNet-SILENT.py

+28-5
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from pyrosetta import *
1515
from pyrosetta.rosetta import *
16-
init(extra_options = "-constant_seed -mute all")
16+
init(extra_options = "-constant_seed -mute all -read_only_ATOM_entries")
1717

1818
def get_lddt(estogram, mask, center=7, weights=[1,1,1,1]):
1919
# Remove diagonal from the mask.
@@ -57,6 +57,12 @@ def main():
5757
default=False,
5858
help="Make binder related predictions (Assumes chain A to be a binder).")
5959

60+
parser.add_argument("--savehidden",
61+
"-sh", action="store",
62+
type=str,
63+
default="",
64+
help="saves last hidden layer if not empty (Default: "")")
65+
6066
parser.add_argument("--reprocess",
6167
"-r",
6268
action="store_true",
@@ -119,11 +125,17 @@ def main():
119125
# Open with append
120126
if not isfile(args.outfile) or args.reprocess:
121127
outfile = open(args.outfile, "w")
122-
outfile.write("name, global_lddt, interface_lddt, binder_lddt\n")
128+
if args.binder:
129+
outfile.write("name, global_lddt, interface_lddt, binder_lddt\n")
130+
else:
131+
outfile.write("name, global_lddt\n")
123132
done = []
124133
else:
125134
outfile = open(args.outfile, "a")
126135
done = pd.read_csv(args.outfile)["name"].values
136+
137+
if args.savehidden != "" and not isdir(args.savehidden):
138+
os.mkdir(args.savehidden)
127139

128140
with torch.no_grad():
129141
# Parse through poses
@@ -132,9 +144,10 @@ def main():
132144

133145
input_stream.fill_pose(pose)
134146
name = core.pose.tag_from_pose(pose)
135-
print(name)
136147
if name in done:
148+
print(name, "is already done.")
137149
continue
150+
print("Working on", name)
138151
per_sample_result = [name]
139152

140153
# This is where featurization happens
@@ -148,7 +161,12 @@ def main():
148161
idx_g = torch.Tensor(idx.astype(np.int32)).long().to(device)
149162
val_g = torch.Tensor(val).to(device)
150163

151-
estogram, mask, lddt, dmy = model(idx_g, val_g, f1d_g, f2d_g)
164+
if args.savehidden != "":
165+
estogram, mask, lddt, hidden, dmy = model(idx_g, val_g, f1d_g, f2d_g, output_hidden_layer=True)
166+
hidden = hidden.cpu().detach().numpy()
167+
np.save(join(args.savehidden, name+".npy"), hidden)
168+
else:
169+
estogram, mask, lddt, dmy = model(idx_g, val_g, f1d_g, f2d_g)
152170
lddt = lddt.cpu().detach().numpy()
153171
estogram = estogram.cpu().detach().numpy()
154172
mask = mask.cpu().detach().numpy()
@@ -177,7 +195,12 @@ def main():
177195
val = val[index]
178196
idx_g = torch.Tensor(idx.astype(np.int32)).long().to(device)
179197
val_g = torch.Tensor(val).to(device)
180-
estogram, mask, lddt, dmy = model(idx_g, val_g, f1d_g[:blen], f2d_g[:, :, :blen, :blen])
198+
if args.savehidden != "":
199+
estogram, mask, lddt, hidden, dmy = model(idx_g, val_g, f1d_g[:blen], f2d_g[:, :, :blen, :blen], output_hidden_layer=True)
200+
hidden = hidden.cpu().detach().numpy()
201+
np.save(join(args.savehidden, name+"_b.npy"), hidden)
202+
else:
203+
estogram, mask, lddt, dmy = model(idx_g, val_g, f1d_g[:blen], f2d_g[:, :, :blen, :blen])
181204
lddt = lddt.cpu().detach().numpy()
182205
estogram = estogram.cpu().detach().numpy()
183206
mask = mask.cpu().detach().numpy()

DeepAccNet-noPyRosetta.py

-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,6 @@ def main():
176176
###########################
177177

178178
if args.bert:
179-
print("hi", [join(args.output, "bert_"+s+".npy") for s in samples])
180179
samples = [s for s in samples if isfile(join(args.output, s+".features.npz")) and isfile(join(args.output, "bert_"+s+".npy"))]
181180

182181
else:

deepAccNet/featurize.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Instantiate pyrosetta
22
from pyrosetta import *
3-
init(extra_options = "-constant_seed -mute all")
3+
init(extra_options = "-constant_seed -mute all -read_only_ATOM_entries")
44

55
# Import necessary libraries
66
import numpy as np

0 commit comments

Comments
 (0)