@@ -44,6 +44,46 @@ def getData(tmp, cutoff=0, bertpath=""):
44
44
45
45
return _3d , _1d , _2d , _truth
46
46
47
+ # GET DATA
48
+ def getData_from_dict (data , cutoff = 0 , bertpath = "" ):
49
+
50
+ # 3D coordinate information
51
+ idx = data ["idx" ]
52
+ val = data ["val" ]
53
+
54
+ # 1D information
55
+ angles = np .stack ([np .sin (data ["phi" ]),
56
+ np .cos (data ["phi" ]),
57
+ np .sin (data ["psi" ]),
58
+ np .cos (data ["psi" ])], axis = - 1 )
59
+ obt = data ["obt" ].T
60
+ prop = data ["prop" ].T
61
+
62
+ # 2D information
63
+ orientations = np .stack ([data ["omega6d" ], data ["theta6d" ], data ["phi6d" ]], axis = - 1 )
64
+ orientations = np .concatenate ([np .sin (orientations ), np .cos (orientations )], axis = - 1 )
65
+ euler = np .concatenate ([np .sin (data ["euler" ]), np .cos (data ["euler" ])], axis = - 1 )
66
+ maps = data ["maps" ]
67
+ tbt = data ["tbt" ].T
68
+ sep = seqsep (tbt .shape [0 ])
69
+
70
+ # Transform input distance
71
+ tbt [:,:,0 ] = transform (tbt [:,:,0 ])
72
+ maps = transform (maps , cutoff = cutoff )
73
+
74
+ _3d = (idx , val )
75
+ _1d = (np .concatenate ([angles , obt , prop ], axis = - 1 ), None )
76
+
77
+ if bertpath != "" :
78
+ bert = np .load (bertpath )
79
+ bert = np .transpose (bert , [1 ,2 ,0 ])
80
+ _2d = np .concatenate ([tbt , maps , euler , orientations , sep , bert ], axis = - 1 )
81
+ else :
82
+ _2d = np .concatenate ([tbt , maps , euler , orientations , sep ], axis = - 1 )
83
+ _truth = None
84
+
85
+ return _3d , _1d , _2d , _truth
86
+
47
87
# VARIANCE REDUCTION
48
88
def transform (X , cutoff = 4 , scaling = 3.0 ):
49
89
X_prime = np .maximum (X , np .zeros_like (X ) + cutoff ) - cutoff
@@ -90,6 +130,10 @@ def clean(samples, outfolder, ensemble=False, verbose=False):
90
130
if verbose : print ("Removing" , join (outfolder , samples [i ]+ ".features.npz" ))
91
131
if isfile (join (outfolder , samples [i ]+ ".features.npz" )):
92
132
os .remove (join (outfolder , samples [i ]+ ".features.npz" ))
133
+ if isfile (join (outfolder , samples [i ]+ ".fa" )):
134
+ os .remove (join (outfolder , samples [i ]+ ".fa" ))
135
+ if isfile (join (outfolder , "bert_" + samples [i ]+ ".npy" )):
136
+ os .remove (join (outfolder , "bert_" + samples [i ]+ ".npy" ))
93
137
if ensemble :
94
138
for j in ["best" , "second" , "third" , "fourth" ]:
95
139
if verbose : print ("Removing" , join (outfolder , samples [i ]+ "_" + j + ".npz" ))
0 commit comments