diff --git a/maui/model.py b/maui/model.py index 4cce516..95e37ca 100644 --- a/maui/model.py +++ b/maui/model.py @@ -186,9 +186,8 @@ def transform(self, X, encoder="mean"): index=self.x_.index, columns=[f"LF{i}" for i in range(1, self.n_latent + 1)], ) - self.feature_correlations = maui.utils.correlate_factors_and_features( - self.z_, self.x_ - ) + + self.feature_correlations_ = None self.w_ = None return self.z_ @@ -447,6 +446,24 @@ def get_linear_weights(self): ) return self.w_ + def get_feature_correlations(self): + """Get correlation coefficients between input features and latent factors. + + Returns + ------- + r: (n_features, n_latent_factors) DataFrame + r_{ij} is the correlation coefficient between feature `i` + and latent factor `j`. + """ + if ( + not hasattr(self, "feature_correlations_") + or self.feature_correlations_ is None + ): + self.feature_correlations_ = maui.utils.correlate_factors_and_features( + self.z_, self.x_ + ) + return self.feature_correlations_ + def drop_unexplanatory_factors(self, threshold=0.02): """Drops factors which have a low R^2 score in a univariate linear model predicting the features `x` from a column of the latent factors `z`. diff --git a/test/test_maui.py b/test/test_maui.py index 513ee68..a876e00 100644 --- a/test/test_maui.py +++ b/test/test_maui.py @@ -67,7 +67,9 @@ def test_dict2array(): def test_maui_saves_feature_correlations(): maui_model = Maui(n_hidden=[10], n_latent=2, epochs=1) z = maui_model.fit_transform({"d1": df1, "d2": df2}) - assert hasattr(maui_model, "feature_correlations") + r = maui_model.get_feature_correlations() + assert r is not None + assert hasattr(maui_model, "feature_correlations_") def test_maui_saves_w():