Skip to content

Commit 21c1775

Browse files
author
Minkyung Baek
committed
Add --per_res_only option. When this flag is provided, it'll write per-residue accuracy estimation only
1 parent c4257af commit 21c1775

File tree

3 files changed

+47
-21
lines changed

3 files changed

+47
-21
lines changed

DeepAccNet.py

+31-13
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ def main():
3434
action="store_true",
3535
default=False,
3636
help="Writing results to a csv file (Default: False)")
37+
38+
parser.add_argument("--per_res_only",
39+
"-pr",
40+
action="store_true",
41+
default=False,
42+
help="Store per-residue accuracy only (Default: False)")
3743

3844
parser.add_argument("--leaveTempFile",
3945
"-lt",
@@ -242,15 +248,23 @@ def main():
242248

243249
if not args.csv:
244250
if args.ensemble:
245-
np.savez_compressed(join(args.output, s+"_"+modelname[:-4]+".npz"),
246-
lddt = lddt.cpu().detach().numpy().astype(np.float16),
247-
estogram = estogram.cpu().detach().numpy().astype(np.float16),
248-
mask = mask.cpu().detach().numpy().astype(np.float16))
251+
if args.per_res_only:
252+
np.savez_compressed(join(args.output, s+"_"+modelname[:-4]+".npz"),
253+
lddt = lddt.cpu().detach().numpy().astype(np.float16))
254+
else:
255+
np.savez_compressed(join(args.output, s+"_"+modelname[:-4]+".npz"),
256+
lddt = lddt.cpu().detach().numpy().astype(np.float16),
257+
estogram = estogram.cpu().detach().numpy().astype(np.float16),
258+
mask = mask.cpu().detach().numpy().astype(np.float16))
249259
else:
250-
np.savez_compressed(join(args.output, s+".npz"),
251-
lddt = lddt.cpu().detach().numpy().astype(np.float16),
252-
estogram = estogram.cpu().detach().numpy().astype(np.float16),
253-
mask = mask.cpu().detach().numpy().astype(np.float16))
260+
if args.per_res_only:
261+
np.savez_compressed(join(args.output, s+".npz"),
262+
lddt = lddt.cpu().detach().numpy().astype(np.float16))
263+
else:
264+
np.savez_compressed(join(args.output, s+".npz"),
265+
lddt = lddt.cpu().detach().numpy().astype(np.float16),
266+
estogram = estogram.cpu().detach().numpy().astype(np.float16),
267+
mask = mask.cpu().detach().numpy().astype(np.float16))
254268
except:
255269
print("Failed to predict for", join(args.output, s+"_"+modelname[:-4]+".npz"))
256270

@@ -311,10 +325,14 @@ def main():
311325
val = torch.Tensor(val).to(device)
312326

313327
estogram, mask, lddt, dmy = model(idx, val, f1d, f2d)
314-
np.savez_compressed(outsamplename+".npz",
315-
lddt = lddt.cpu().detach().numpy().astype(np.float16),
316-
estogram = estogram.cpu().detach().numpy().astype(np.float16),
317-
mask = mask.cpu().detach().numpy().astype(np.float16))
328+
if args.per_res_only:
329+
np.savez_compressed(outsamplename+".npz",
330+
lddt = lddt.cpu().detach().numpy().astype(np.float16))
331+
else:
332+
np.savez_compressed(outsamplename+".npz",
333+
lddt = lddt.cpu().detach().numpy().astype(np.float16),
334+
estogram = estogram.cpu().detach().numpy().astype(np.float16),
335+
mask = mask.cpu().detach().numpy().astype(np.float16))
318336

319337
if not args.leaveTempFile:
320338
dan.clean([outsamplename],
@@ -326,4 +344,4 @@ def main():
326344

327345

328346
if __name__== "__main__":
329-
main()
347+
main()

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ optional arguments:
2020
-h, --help show this help message and exit
2121
--pdb, -pdb Running on a single pdb file instead of a folder (Default: False)
2222
--csv, -csv Writing results to a csv file (Default: False)
23+
--per_res_only, -pr Writing per-residue accuracy only (Default: False)
2324
--leaveTempFile, -lt Leaving temporary files (Default: False)
2425
--process PROCESS, -p PROCESS
2526
Specifying # of cpus to use for featurization (Default: 1)

deepAccNet/utils.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def seqsep(psize, normalizer=100, axis=-1):
9797
ret[i,j] = abs(i-j)*1.0/100-1.0
9898
return np.expand_dims(ret, axis)
9999

100-
def merge(samples, outfolder, verbose=False):
100+
def merge(samples, outfolder, per_res_only=False, verbose=False):
101101
for j in range(len(samples)):
102102
try:
103103
if verbose: print("Merging", samples[j])
@@ -108,19 +108,26 @@ def merge(samples, outfolder, verbose=False):
108108
for i in ["best", "second", "third", "fourth"]:
109109
temp = np.load(join(outfolder, samples[j]+"_"+i+".npz"))
110110
lddt.append(temp["lddt"])
111+
if per_res_only:
112+
continue
111113
estogram.append(temp["estogram"])
112114
mask.append(temp["mask"])
113115

114116
# Averaging
115117
lddt = np.mean(lddt, axis=0)
116-
estogram = np.mean(estogram, axis=0)
117-
mask = np.mean(mask, axis=0)
118+
if not per_res_only:
119+
estogram = np.mean(estogram, axis=0)
120+
mask = np.mean(mask, axis=0)
118121

119122
# Saving
120-
np.savez_compressed(join(outfolder, samples[j]+".npz"),
121-
lddt = lddt.astype(np.float16),
122-
estogram = estogram.astype(np.float16),
123-
mask = mask.astype(np.float16))
123+
if per_res_only:
124+
np.savez_compressed(join(outfolder, samples[j]+".npz"),
125+
lddt = lddt.astype(np.float16))
126+
else:
127+
np.savez_compressed(join(outfolder, samples[j]+".npz"),
128+
lddt = lddt.astype(np.float16),
129+
estogram = estogram.astype(np.float16),
130+
mask = mask.astype(np.float16))
124131
except:
125132
print("Failed to merge for", join(outfolder, samples[j]+".npz"))
126133

@@ -140,4 +147,4 @@ def clean(samples, outfolder, ensemble=False, verbose=False):
140147
if isfile(join(outfolder, samples[i]+"_"+j+".npz")):
141148
os.remove(join(outfolder, samples[i]+"_"+j+".npz"))
142149
except:
143-
print("Failed to clean for", samples[i])
150+
print("Failed to clean for", samples[i])

0 commit comments

Comments
 (0)