Skip to content

Commit 738b6d8

Browse files
author
Naozumi Hiranuma
committed
script and figure
1 parent 910974b commit 738b6d8

File tree

4 files changed

+329
-0
lines changed

4 files changed

+329
-0
lines changed

DeepAccNet.py

+329
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
import sys
2+
import argparse
3+
import os
4+
from os import listdir
5+
from os.path import isfile, isdir, join
6+
import numpy as np
7+
import pandas as pd
8+
import multiprocessing
9+
import torch
10+
11+
def main():
12+
#####################
13+
# Parsing arguments
14+
#####################
15+
parser = argparse.ArgumentParser(description="Error predictor network",
16+
epilog="v0.0.1")
17+
parser.add_argument("input",
18+
action="store",
19+
help="path to input folder or input pdb file")
20+
21+
parser.add_argument("output",
22+
action="store", nargs=argparse.REMAINDER,
23+
help="path to output (folder path, npz, or csv)")
24+
25+
parser.add_argument("--modelpath",
26+
"-modelpath",
27+
action="store",
28+
default="NatComm_standard",
29+
help="modelpath (default: NatComm_standard")
30+
31+
parser.add_argument("--pdb",
32+
"-pdb",
33+
action="store_true",
34+
default=False,
35+
help="Running on a single pdb file instead of a folder (Default: False)")
36+
37+
parser.add_argument("--csv",
38+
"-csv",
39+
action="store_true",
40+
default=False,
41+
help="Writing results to a csv file (Default: False)")
42+
43+
parser.add_argument("--leaveTempFile",
44+
"-lt",
45+
action="store_true",
46+
default=False,
47+
help="Leaving temporary files (Default: False)")
48+
49+
parser.add_argument("--process",
50+
"-p", action="store",
51+
type=int,
52+
default=1,
53+
help="Specifying # of cpus to use for featurization (Default: 1)")
54+
55+
parser.add_argument("--featurize",
56+
"-f",
57+
action="store_true",
58+
default=False,
59+
help="Running only the featurization part(Default: False)")
60+
61+
parser.add_argument("--reprocess",
62+
"-r", action="store_true",
63+
default=False,
64+
help="Reprocessing all feature files (Default: False)")
65+
66+
parser.add_argument("--verbose",
67+
"-v",
68+
action="store_true",
69+
default=False,
70+
help="Activating verbose flag (Default: False)")
71+
72+
parser.add_argument("--bert",
73+
"-bert",
74+
action="store_true",
75+
default=False,
76+
help="Run with bert features. Use extractBert.py to generate them. (Default: False)")
77+
78+
parser.add_argument("--ensemble",
79+
"-e",
80+
action="store_true",
81+
default=False,
82+
help="Running with ensembling of 4 models. This adds 4x computational time with some overheads (Default: False)")
83+
84+
args = parser.parse_args()
85+
86+
################################
87+
# Checking file availabilities #
88+
################################
89+
csvfilename = "result.csv"
90+
91+
# made outfolder an optional positinal argument. So check manually it's lenght and unpack the string
92+
if len(args.output)>1:
93+
print(f"Only one output folder can be specified, but got {args.output}", file=sys.stderr)
94+
return -1
95+
96+
if len(args.output)==0:
97+
args.output = ""
98+
else:
99+
args.output = args.output[0]
100+
101+
if args.input.endswith('.pdb'):
102+
args.pdb = True
103+
104+
if args.output.endswith(".csv"):
105+
args.csv = True
106+
107+
if not args.pdb:
108+
if not isdir(args.input):
109+
print("Input folder does not exist.", file=sys.stderr)
110+
return -1
111+
112+
#default is input folder
113+
if args.output == "":
114+
args.output = args.input
115+
else:
116+
if not args.csv and not isdir(args.output):
117+
if args.verbose: print("Creating output folder:", args.output)
118+
os.mkdir(args.output)
119+
120+
# if csv, do it in place.
121+
elif args.csv:
122+
csvfilename = args.output
123+
args.output = args.input
124+
125+
else:
126+
if not isfile(args.input):
127+
print("Input file does not exist.", file=sys.stderr)
128+
return -1
129+
130+
#default is output name with extension changed to npz
131+
if args.output == "":
132+
args.output = os.path.splitext(args.input)[0]+".npz"
133+
134+
if not(".pdb" in args.input and ".npz" in args.output):
135+
print("Input needs to be in .pdb format, and output needs to be in .npz format.", file=sys.stderr)
136+
return -1
137+
138+
script_dir = os.path.dirname(__file__)
139+
base = os.path.join(script_dir, "models/")
140+
modelpath = join(base, args.modelpath)
141+
142+
# Eensemble is disabled right now.
143+
if not isdir(modelpath):
144+
print("Model checkpoint does not exist", file=sys.stderr)
145+
return -1
146+
147+
##############################
148+
# Importing larger libraries #
149+
##############################
150+
script_dir = os.path.dirname(__file__)
151+
sys.path.insert(0, script_dir)
152+
import deepAccNet as dan
153+
154+
num_process = 1
155+
if args.process > 1:
156+
num_process = args.process
157+
158+
#########################
159+
# Getting samples names #
160+
#########################
161+
if not args.pdb:
162+
samples = [i[:-4] for i in os.listdir(args.input) if isfile(args.input+"/"+i) and i[-4:] == ".pdb" and i[0]!="."]
163+
ignored = [i[:-4] for i in os.listdir(args.input) if not(isfile(args.input+"/"+i) and i[-4:] == ".pdb" and i[0]!=".")]
164+
if args.verbose:
165+
print("# samples:", len(samples))
166+
if len(ignored) > 0:
167+
print("# files ignored:", len(ignored))
168+
169+
##############################
170+
# Featurization happens here #
171+
##############################
172+
inputs = [join(args.input, s)+".pdb" for s in samples]
173+
tmpoutputs = [join(args.output, s)+".features.npz" for s in samples]
174+
175+
if not args.reprocess:
176+
arguments = [(inputs[i], tmpoutputs[i], args.verbose) for i in range(len(inputs)) if not isfile(tmpoutputs[i])]
177+
already_processed = [(inputs[i], tmpoutputs[i], args.verbose) for i in range(len(inputs)) if isfile(tmpoutputs[i])]
178+
if args.verbose:
179+
print("Featurizing", len(arguments), "samples.", len(already_processed), "are already processed.")
180+
else:
181+
arguments = [(inputs[i], tmpoutputs[i], args.verbose) for i in range(len(inputs))]
182+
already_processed = [(inputs[i], tmpoutputs[i], args.verbose) for i in range(len(inputs)) if isfile(tmpoutputs[i])]
183+
if args.verbose:
184+
print("Featurizing", len(arguments), "samples.", len(already_processed), "are re-processed.")
185+
186+
if num_process == 1:
187+
for a in arguments:
188+
dan.process(a)
189+
else:
190+
pool = multiprocessing.Pool(num_process)
191+
out = pool.map(dan.process, arguments)
192+
193+
# Exit if only featurization is needed
194+
if args.featurize:
195+
return 0
196+
197+
if args.verbose: print("using", modelpath)
198+
199+
###########################
200+
# Prediction happens here #
201+
###########################
202+
203+
if args.bert:
204+
print("hi", [join(args.output, "bert_"+s+".npy") for s in samples])
205+
samples = [s for s in samples if isfile(join(args.output, s+".features.npz")) and isfile(join(args.output, "bert_"+s+".npy"))]
206+
207+
else:
208+
samples = [s for s in samples if isfile(join(args.output, s+".features.npz"))]
209+
210+
# Load pytorch model:
211+
if args.ensemble:
212+
modelnames = ["best.pkl", "second.pkl", "third.pkl", "fourth.pkl"]
213+
else:
214+
modelnames = ["best.pkl"]
215+
216+
result = {}
217+
for modelname in modelnames:
218+
model = dan.DeepAccNet(twobody_size = 49 if args.bert else 33)
219+
checkpoint = torch.load(join(modelpath, modelname))
220+
model.load_state_dict(checkpoint["model_state_dict"])
221+
device = torch.device("cuda:0" if torch.cuda.is_available() or args.cpu else "cpu")
222+
model.to(device)
223+
model.eval()
224+
225+
for s in samples:
226+
try:
227+
with torch.no_grad():
228+
if args.verbose: print("Predicting for", s)
229+
filename = join(args.output, s+".features.npz")
230+
if args.bert:
231+
bertname = join(args.output, "bert_"+s+".npy")
232+
else:
233+
bertname = ""
234+
(idx, val), (f1d, bert), f2d, dmy = dan.getData(filename, bertpath = bertname)
235+
f1d = torch.Tensor(f1d).to(device)
236+
f2d = torch.Tensor(np.expand_dims(f2d.transpose(2,0,1), 0)).to(device)
237+
idx = torch.Tensor(idx.astype(np.int32)).long().to(device)
238+
val = torch.Tensor(val).to(device)
239+
240+
estogram, mask, lddt, dmy = model(idx, val, f1d, f2d)
241+
t = result.get(s, [])
242+
t.append(np.mean(lddt.cpu().detach().numpy()))
243+
result[s] = t
244+
245+
if not args.csv:
246+
if args.ensemble:
247+
np.savez_compressed(join(args.output, s+"_"+modelname[:-4]+".npz"),
248+
lddt = lddt.cpu().detach().numpy().astype(np.float16),
249+
estogram = estogram.cpu().detach().numpy().astype(np.float16),
250+
mask = mask.cpu().detach().numpy().astype(np.float16))
251+
else:
252+
np.savez_compressed(join(args.output, s+".npz"),
253+
lddt = lddt.cpu().detach().numpy().astype(np.float16),
254+
estogram = estogram.cpu().detach().numpy().astype(np.float16),
255+
mask = mask.cpu().detach().numpy().astype(np.float16))
256+
except:
257+
print("Failed to predict for", join(args.output, s+"_"+modelname[:-4]+".npz"))
258+
259+
if not args.csv:
260+
261+
if args.ensemble:
262+
dan.merge(samples, args.output, verbose=args.verbose)
263+
264+
if not args.leaveTempFile:
265+
dan.clean(samples,
266+
args.output,
267+
verbose=args.verbose,
268+
ensemble=args.ensemble)
269+
else:
270+
# Take average of outputs
271+
csvfile = open(csvfilename, "w")
272+
csvfile.write("sample\tcb-lddt\n")
273+
for s in samples:
274+
line = "%s\t%.4f\n"%(s, np.mean(result[s]))
275+
csvfile.write(line)
276+
csvfile.close()
277+
278+
# Processing for single sample
279+
else:
280+
infilepath = args.input
281+
outfilepath = args.output
282+
infolder = "/".join(infilepath.split("/")[:-1])
283+
insamplename = infilepath.split("/")[-1][:-4]
284+
outfolder = "/".join(outfilepath.split("/")[:-1])
285+
outsamplename = outfilepath.split("/")[-1][:-4]
286+
feature_file_name = join(outfolder, outsamplename+".features.npz")
287+
if args.verbose:
288+
print("only working on a file:", outfolder, outsamplename)
289+
# Process if file does not exists or reprocess flag is set
290+
291+
if (not isfile(feature_file_name)) or args.reprocess:
292+
dan.process((join(infolder, insamplename+".pdb"),
293+
feature_file_name,
294+
args.verbose))
295+
296+
if isfile(feature_file_name):
297+
# Load pytorch model:
298+
model = dan.DeepAccNet()
299+
model.load_state_dict(torch.load("models/regular_rep1/weights.pkl"))
300+
device = torch.device("cuda:0" if torch.cuda.is_available() or args.cpu else "cpu")
301+
model.to(device)
302+
model.eval()
303+
304+
# Actual prediction
305+
with torch.no_grad():
306+
if args.verbose: print("Predicting for", outsamplename)
307+
(idx, val), (f1d, bert), f2d, dmy = dan.getData(feature_file_name)
308+
f1d = torch.Tensor(f1d).to(device)
309+
f2d = torch.Tensor(np.expand_dims(f2d.transpose(2,0,1), 0)).to(device)
310+
idx = torch.Tensor(idx.astype(np.int32)).long().to(device)
311+
val = torch.Tensor(val).to(device)
312+
313+
estogram, mask, lddt, dmy = model(idx, val, f1d, f2d)
314+
np.savez_compressed(outsamplename+".npz",
315+
lddt = lddt.cpu().detach().numpy().astype(np.float16),
316+
estogram = estogram.cpu().detach().numpy().astype(np.float16),
317+
mask = mask.cpu().detach().numpy().astype(np.float16))
318+
319+
if not args.leaveTempFile:
320+
dan.clean([outsamplename],
321+
outfolder,
322+
verbose=args.verbose,
323+
noEnsemble=True)
324+
else:
325+
print(f"Feature file does not exist: {feature_file_name}", file=sys.stderr)
326+
327+
328+
if __name__== "__main__":
329+
main()

figures/concept.png

59.1 KB
Loading

figures/concept2.png

407 KB
Loading

figures/ipdlogo.png

20.1 KB
Loading

0 commit comments

Comments
 (0)