Skip to content

Commit

Permalink
Don't automatically compute feature correlations; this is a performan…
Browse files Browse the repository at this point in the history
…ce enhancement
  • Loading branch information
jonathanronen committed Oct 8, 2019
1 parent d406d22 commit 416708d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
23 changes: 20 additions & 3 deletions maui/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,8 @@ def transform(self, X, encoder="mean"):
index=self.x_.index,
columns=[f"LF{i}" for i in range(1, self.n_latent + 1)],
)
self.feature_correlations = maui.utils.correlate_factors_and_features(
self.z_, self.x_
)

self.feature_correlations_ = None
self.w_ = None
return self.z_

Expand Down Expand Up @@ -447,6 +446,24 @@ def get_linear_weights(self):
)
return self.w_

def get_feature_correlations(self):
"""Get correlation coefficients between input features and latent factors.
Returns
-------
r: (n_features, n_latent_factors) DataFrame
r_{ij} is the correlation coefficient between feature `i`
and latent factor `j`.
"""
if (
not hasattr(self, "feature_correlations_")
or self.feature_correlations_ is None
):
self.feature_correlations_ = maui.utils.correlate_factors_and_features(
self.z_, self.x_
)
return self.feature_correlations_

def drop_unexplanatory_factors(self, threshold=0.02):
"""Drops factors which have a low R^2 score in a univariate linear model
predicting the features `x` from a column of the latent factors `z`.
Expand Down
4 changes: 3 additions & 1 deletion test/test_maui.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def test_dict2array():
def test_maui_saves_feature_correlations():
maui_model = Maui(n_hidden=[10], n_latent=2, epochs=1)
z = maui_model.fit_transform({"d1": df1, "d2": df2})
assert hasattr(maui_model, "feature_correlations")
r = maui_model.get_feature_correlations()
assert r is not None
assert hasattr(maui_model, "feature_correlations_")


def test_maui_saves_w():
Expand Down

0 comments on commit 416708d

Please sign in to comment.