Skip to content

Commit a14192a

Browse files
author
Naozumi Hiranuma
committed
updated the cleaning function
1 parent 6648723 commit a14192a

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

deepAccNet/utils.py

+44
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,46 @@ def getData(tmp, cutoff=0, bertpath=""):
4444

4545
return _3d, _1d, _2d, _truth
4646

47+
# GET DATA
48+
def getData_from_dict(data, cutoff=0, bertpath=""):
49+
50+
# 3D coordinate information
51+
idx = data["idx"]
52+
val = data["val"]
53+
54+
# 1D information
55+
angles = np.stack([np.sin(data["phi"]),
56+
np.cos(data["phi"]),
57+
np.sin(data["psi"]),
58+
np.cos(data["psi"])], axis=-1)
59+
obt = data["obt"].T
60+
prop = data["prop"].T
61+
62+
# 2D information
63+
orientations = np.stack([data["omega6d"], data["theta6d"], data["phi6d"]], axis=-1)
64+
orientations = np.concatenate([np.sin(orientations), np.cos(orientations)], axis=-1)
65+
euler = np.concatenate([np.sin(data["euler"]), np.cos(data["euler"])], axis=-1)
66+
maps = data["maps"]
67+
tbt = data["tbt"].T
68+
sep = seqsep(tbt.shape[0])
69+
70+
# Transform input distance
71+
tbt[:,:,0] = transform(tbt[:,:,0])
72+
maps = transform(maps, cutoff=cutoff)
73+
74+
_3d = (idx, val)
75+
_1d = (np.concatenate([angles, obt, prop], axis=-1), None)
76+
77+
if bertpath!="":
78+
bert = np.load(bertpath)
79+
bert = np.transpose(bert, [1,2,0])
80+
_2d = np.concatenate([tbt, maps, euler, orientations, sep, bert], axis=-1)
81+
else:
82+
_2d = np.concatenate([tbt, maps, euler, orientations, sep], axis=-1)
83+
_truth = None
84+
85+
return _3d, _1d, _2d, _truth
86+
4787
# VARIANCE REDUCTION
4888
def transform(X, cutoff=4, scaling=3.0):
4989
X_prime = np.maximum(X, np.zeros_like(X) + cutoff) - cutoff
@@ -90,6 +130,10 @@ def clean(samples, outfolder, ensemble=False, verbose=False):
90130
if verbose: print("Removing", join(outfolder, samples[i]+".features.npz"))
91131
if isfile(join(outfolder, samples[i]+".features.npz")):
92132
os.remove(join(outfolder, samples[i]+".features.npz"))
133+
if isfile(join(outfolder, samples[i]+".fa")):
134+
os.remove(join(outfolder, samples[i]+".fa"))
135+
if isfile(join(outfolder, "bert_"+samples[i]+".npy")):
136+
os.remove(join(outfolder, "bert_"+samples[i]+".npy"))
93137
if ensemble:
94138
for j in ["best", "second", "third", "fourth"]:
95139
if verbose: print("Removing", join(outfolder, samples[i]+"_"+j+".npz"))

deepAccNet_noPyRosetta/utils.py

+4
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ def clean(samples, outfolder, ensemble=False, verbose=False):
9090
if verbose: print("Removing", join(outfolder, samples[i]+".features.npz"))
9191
if isfile(join(outfolder, samples[i]+".features.npz")):
9292
os.remove(join(outfolder, samples[i]+".features.npz"))
93+
if isfile(join(outfolder, samples[i]+".fa")):
94+
os.remove(join(outfolder, samples[i]+".fa"))
95+
if isfile(join(outfolder, "bert_"+samples[i]+".npy")):
96+
os.remove(join(outfolder, "bert_"+samples[i]+".npy"))
9397
if ensemble:
9498
for j in ["best", "second", "third", "fourth"]:
9599
if verbose: print("Removing", join(outfolder, samples[i]+"_"+j+".npz"))

0 commit comments

Comments
 (0)