From 991cd0b38fdeb19a6220c065a77a12e8a9552178 Mon Sep 17 00:00:00 2001 From: LeoGrin Date: Tue, 21 Jan 2025 17:52:36 +0100 Subject: [PATCH] run precommit --- src/tabpfn/classifier.py | 2 +- src/tabpfn/preprocessing.py | 2 +- src/tabpfn/regressor.py | 8 ++++---- tests/test_classifier_interface.py | 23 +++++++++++---------- tests/test_regressor_interface.py | 33 +++++++++++++++--------------- 5 files changed, 34 insertions(+), 34 deletions(-) diff --git a/src/tabpfn/classifier.py b/src/tabpfn/classifier.py index a66afcda..61f96978 100644 --- a/src/tabpfn/classifier.py +++ b/src/tabpfn/classifier.py @@ -33,13 +33,13 @@ determine_precision, initialize_tabpfn_model, ) +from tabpfn.config import ModelInterfaceConfig from tabpfn.constants import ( PROBABILITY_EPSILON_ROUND_ZERO, SKLEARN_16_DECIMAL_PRECISION, XType, YType, ) -from tabpfn.config import ModelInterfaceConfig from tabpfn.preprocessing import ( ClassifierEnsembleConfig, EnsembleConfig, diff --git a/src/tabpfn/preprocessing.py b/src/tabpfn/preprocessing.py index 6c7cf506..85c232b6 100644 --- a/src/tabpfn/preprocessing.py +++ b/src/tabpfn/preprocessing.py @@ -153,7 +153,7 @@ def to_dict(self) -> dict: } @classmethod - def from_dict(cls, config_dict: dict) -> "PreprocessorConfig": + def from_dict(cls, config_dict: dict) -> PreprocessorConfig: """Create a config from a dictionary. Args: diff --git a/src/tabpfn/regressor.py b/src/tabpfn/regressor.py index 91073e6e..68774905 100644 --- a/src/tabpfn/regressor.py +++ b/src/tabpfn/regressor.py @@ -37,10 +37,6 @@ determine_precision, initialize_tabpfn_model, ) -from tabpfn.constants import ( - XType, - YType, -) from tabpfn.config import ModelInterfaceConfig from tabpfn.model.bar_distribution import FullSupportBarDistribution from tabpfn.model.preprocessing import ( @@ -70,6 +66,10 @@ from sklearn.pipeline import Pipeline from torch.types import _dtype + from tabpfn.constants import ( + XType, + YType, + ) from tabpfn.inference import ( InferenceEngine, ) diff --git a/tests/test_classifier_interface.py b/tests/test_classifier_interface.py index a14ba98d..dfd5790a 100644 --- a/tests/test_classifier_interface.py +++ b/tests/test_classifier_interface.py @@ -171,49 +171,50 @@ def test_classifier_in_pipeline(X_y: tuple[np.ndarray, np.ndarray]) -> None: rtol=0.1, ), "Class probabilities are not properly balanced in pipeline" + def test_dict_vs_object_preprocessor_config(X_y: tuple[np.ndarray, np.ndarray]) -> None: """Test that dict configs behave identically to PreprocessorConfig objects.""" X, y = X_y - + # Define same config as both dict and object dict_config = { "name": "quantile_uni_coarse", - "append_original": False, # changed from default + "append_original": False, # changed from default "categorical_name": "ordinal_very_common_categories_shuffled", "global_transformer_name": "svd", "subsample_features": -1, } - + object_config = PreprocessorConfig( name="quantile_uni_coarse", - append_original=False, # changed from default + append_original=False, # changed from default categorical_name="ordinal_very_common_categories_shuffled", global_transformer_name="svd", subsample_features=-1, ) - + # Create two models with same random state model_dict = TabPFNClassifier( inference_config={"PREPROCESS_TRANSFORMS": [dict_config]}, n_estimators=2, - random_state=42 + random_state=42, ) - + model_obj = TabPFNClassifier( inference_config={"PREPROCESS_TRANSFORMS": [object_config]}, n_estimators=2, - random_state=42 + random_state=42, ) - + # Fit both models model_dict.fit(X, y) model_obj.fit(X, y) - + # Compare predictions pred_dict = model_dict.predict(X) pred_obj = model_obj.predict(X) np.testing.assert_array_equal(pred_dict, pred_obj) - + # Compare probabilities prob_dict = model_dict.predict_proba(X) prob_obj = model_obj.predict_proba(X) diff --git a/tests/test_regressor_interface.py b/tests/test_regressor_interface.py index e7ca781b..ccd48226 100644 --- a/tests/test_regressor_interface.py +++ b/tests/test_regressor_interface.py @@ -157,64 +157,63 @@ def test_regressor_in_pipeline(X_y: tuple[np.ndarray, np.ndarray]) -> None: X.shape[0], ), "Quantile predictions shape is incorrect" + def test_dict_vs_object_preprocessor_config(X_y: tuple[np.ndarray, np.ndarray]) -> None: """Test that dict configs behave identically to PreprocessorConfig objects.""" X, y = X_y - + # Define same config as both dict and object dict_config = { "name": "quantile_uni", - "append_original": False, # changed from default + "append_original": False, # changed from default "categorical_name": "ordinal_very_common_categories_shuffled", "global_transformer_name": "svd", "subsample_features": -1, } - + object_config = PreprocessorConfig( name="quantile_uni", - append_original=False, # changed from default + append_original=False, # changed from default categorical_name="ordinal_very_common_categories_shuffled", global_transformer_name="svd", subsample_features=-1, ) - + # Create two models with same random state model_dict = TabPFNRegressor( inference_config={"PREPROCESS_TRANSFORMS": [dict_config]}, n_estimators=2, - random_state=42 + random_state=42, ) - + model_obj = TabPFNRegressor( inference_config={"PREPROCESS_TRANSFORMS": [object_config]}, n_estimators=2, - random_state=42 + random_state=42, ) - + # Fit both models model_dict.fit(X, y) model_obj.fit(X, y) - + # Compare predictions for different output types for output_type in ["mean", "median", "mode"]: pred_dict = model_dict.predict(X, output_type=output_type) pred_obj = model_obj.predict(X, output_type=output_type) np.testing.assert_array_almost_equal( - pred_dict, + pred_dict, pred_obj, - err_msg=f"Predictions differ for output_type={output_type}" + err_msg=f"Predictions differ for output_type={output_type}", ) - + # Compare quantile predictions quantiles = [0.1, 0.5, 0.9] quant_dict = model_dict.predict(X, output_type="quantiles", quantiles=quantiles) quant_obj = model_obj.predict(X, output_type="quantiles", quantiles=quantiles) - + for q_dict, q_obj in zip(quant_dict, quant_obj): np.testing.assert_array_almost_equal( q_dict, q_obj, - err_msg=f"Quantile predictions differ" + err_msg="Quantile predictions differ", ) - -