Skip to content

Commit

Permalink
Merge pull request #9 from PriorLabs/fix_phe_with_client
Browse files Browse the repository at this point in the history
Fix phe and hpo with TabPFN API
  • Loading branch information
LeoGrin authored Jan 31, 2025
2 parents 3dddd58 + 0eefc67 commit 44ba312
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/tabpfn_extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
__version__ = "0.1.0.dev0"

from .utils import is_tabpfn
from .utils import TabPFNRegressor, TabPFNClassifier, PreprocessorConfig
from .utils import TabPFNRegressor, TabPFNClassifier

from . import utils_todo
15 changes: 7 additions & 8 deletions src/tabpfn_extensions/hpo/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from hyperopt import hp
from pathlib import Path
from tabpfn_extensions import PreprocessorConfig


def enumerate_preprocess_transforms():
Expand Down Expand Up @@ -47,13 +46,13 @@ def enumerate_preprocess_transforms():
for global_transformer_name in [None, "svd"]:
transforms += [
[
PreprocessorConfig(
name=name,
global_transformer_name=global_transformer_name,
subsample_features=subsample_features,
categorical_name=categorical_name,
append_original=append_original,
)
{
"name": name,
"global_transformer_name": global_transformer_name,
"subsample_features": subsample_features,
"categorical_name": categorical_name,
"append_original": append_original,
}
for name in names
],
]
Expand Down
128 changes: 108 additions & 20 deletions src/tabpfn_extensions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
import os

import os
from typing import Any, Type, Tuple, Protocol

from typing import Any, Type, Tuple, Protocol, Literal, Optional, Union, Dict
from dataclasses import dataclass
from typing_extensions import override
from sklearn.base import BaseEstimator
import numpy as np
import warnings

class TabPFNEstimator(Protocol):
def fit(self, X: Any, y: Any) -> Any:
Expand All @@ -29,7 +33,6 @@ def is_tabpfn(estimator: Any) -> bool:
except (AttributeError, TypeError):
return False


from typing import Tuple, Type
import os
# Backend switch: set the env var USE_TABPFN_LOCAL=false to skip the local
# `tabpfn` package and fall through to the `tabpfn_client` backend instead.
# Any value other than (case-insensitive) "true" disables the local path.
USE_TABPFN_LOCAL = os.getenv("USE_TABPFN_LOCAL", "true").lower() == "true"
Expand All @@ -39,9 +42,8 @@ def get_tabpfn_models() -> Tuple[Type, Type, Type]:
if USE_TABPFN_LOCAL:
try:
from tabpfn import TabPFNClassifier, TabPFNRegressor
from tabpfn.preprocessing import PreprocessorConfig

return TabPFNClassifier, TabPFNRegressor, PreprocessorConfig
return TabPFNClassifier, TabPFNRegressor
except ImportError:
pass

Expand All @@ -50,22 +52,108 @@ def get_tabpfn_models() -> Tuple[Type, Type, Type]:
TabPFNClassifier as ClientTabPFNClassifier,
TabPFNRegressor as ClientTabPFNRegressor,
)
from tabpfn_client.estimator import (
PreprocessorConfig as ClientPreprocessorConfig,
)

# Wrapper classes to add device parameter
# Thin subclass that swallows a `device` argument the client backend cannot
# honor (inference runs remotely). NOTE(review): the *args/**kwargs signature
# hides parameters from scikit-learn's introspection (get_params/clone) —
# this is why the replacement wrappers spell every parameter out explicitly.
class TabPFNClassifier(ClientTabPFNClassifier):
    def __init__(self, *args, device=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Ignoring the device parameter for now

# Thin subclass that swallows a `device` argument the client backend cannot
# honor (inference runs remotely). NOTE(review): *args/**kwargs prevents
# scikit-learn from discovering constructor parameters via introspection.
class TabPFNRegressor(ClientTabPFNRegressor):
    def __init__(self, *args, device=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Ignoring the device parameter for now

return TabPFNClassifier, TabPFNRegressor, ClientPreprocessorConfig
# Wrapper classes to add device parameter
# we can't use *args because scikit-learn needs to know the parameters of the constructor
class TabPFNClassifierWrapper(ClientTabPFNClassifier):
    """Adapter around the client-backend TabPFN classifier.

    Accepts (and deliberately ignores) local-only constructor arguments such
    as ``device`` and ``categorical_features_indices`` so that code written
    against the local ``tabpfn`` package also works with ``tabpfn_client``.

    Every parameter is spelled out explicitly — no ``*args``/``**kwargs`` —
    because scikit-learn discovers estimator parameters by introspecting
    ``__init__``.
    """

    def __init__(
        self,
        device: Union[str, None] = None,
        categorical_features_indices: Optional[list[int]] = None,
        model_path: str = "default",
        n_estimators: int = 4,
        softmax_temperature: float = 0.9,
        balance_probabilities: bool = False,
        average_before_softmax: bool = False,
        ignore_pretraining_limits: bool = False,
        inference_precision: Literal["autocast", "auto"] = "auto",
        random_state: Optional[
            Union[int, np.random.RandomState, np.random.Generator]
        ] = None,
        inference_config: Optional[Dict] = None,
        paper_version: bool = False,
    ) -> None:
        # Stored so scikit-learn can round-trip the attribute, but inference
        # runs remotely, so a device selection cannot be honored here.
        self.device = device
        # TODO: we should support this argument in the client version
        self.categorical_features_indices = categorical_features_indices
        if categorical_features_indices is not None:
            warnings.warn(
                "categorical_features_indices is not supported in the client version of TabPFN and will be ignored",
                UserWarning,
                stacklevel=2,
            )
        # Translate a local checkpoint path (e.g. ".../tabpfn-v2-classifier-foo.ckpt")
        # into the short model name the client API expects: take the filename,
        # keep the last dash-separated token, drop the extension. The plain
        # "classifier" checkpoint maps to the client's "default" model.
        if "/" in model_path:
            model_name = model_path.split("/")[-1].split("-")[-1].split(".")[0]
            self.model_path = "default" if model_name == "classifier" else model_name
        else:
            self.model_path = model_path

        super().__init__(
            model_path=self.model_path,
            n_estimators=n_estimators,
            softmax_temperature=softmax_temperature,
            balance_probabilities=balance_probabilities,
            average_before_softmax=average_before_softmax,
            ignore_pretraining_limits=ignore_pretraining_limits,
            inference_precision=inference_precision,
            random_state=random_state,
            inference_config=inference_config,
            paper_version=paper_version,
        )

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        """Return parameters for this estimator, omitting the local-only
        arguments (``device``, ``categorical_features_indices``) that the
        underlying client estimator's constructor does not accept.
        """
        params = super().get_params(deep=deep)
        # Use pop(key, None) so this never raises KeyError if a future
        # client version stops reporting these wrapper-only parameters.
        params.pop("device", None)
        params.pop("categorical_features_indices", None)
        return params

class TabPFNRegressorWrapper(ClientTabPFNRegressor):
    """Adapter around the client-backend TabPFN regressor.

    Accepts (and deliberately ignores) local-only constructor arguments such
    as ``device`` and ``categorical_features_indices`` so that code written
    against the local ``tabpfn`` package also works with ``tabpfn_client``.

    Parameters are spelled out explicitly because scikit-learn discovers
    estimator parameters by introspecting ``__init__``.
    """

    def __init__(
        self,
        device: Union[str, None] = None,
        categorical_features_indices: Optional[list[int]] = None,
        model_path: str = "default",
        n_estimators: int = 8,
        softmax_temperature: float = 0.9,
        average_before_softmax: bool = False,
        ignore_pretraining_limits: bool = False,
        inference_precision: Literal["autocast", "auto"] = "auto",
        random_state: Optional[
            Union[int, np.random.RandomState, np.random.Generator]
        ] = None,
        inference_config: Optional[Dict] = None,
        paper_version: bool = False,
    ) -> None:
        # Stored so scikit-learn can round-trip the attribute, but inference
        # runs remotely, so a device selection cannot be honored here.
        self.device = device
        self.categorical_features_indices = categorical_features_indices
        if categorical_features_indices is not None:
            warnings.warn(
                "categorical_features_indices is not supported in the client version of TabPFN and will be ignored",
                UserWarning,
                stacklevel=2,
            )
        # NOTE(review): unlike the classifier wrapper, model_path is forwarded
        # verbatim — confirm whether local checkpoint paths ("/"-containing)
        # need the same name translation here.
        super().__init__(
            model_path=model_path,
            n_estimators=n_estimators,
            softmax_temperature=softmax_temperature,
            average_before_softmax=average_before_softmax,
            ignore_pretraining_limits=ignore_pretraining_limits,
            inference_precision=inference_precision,
            random_state=random_state,
            inference_config=inference_config,
            paper_version=paper_version,
        )

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        """Return parameters for this estimator, omitting the local-only
        arguments (``device``, ``categorical_features_indices``) that the
        underlying client estimator's constructor does not accept.
        """
        params = super().get_params(deep=deep)
        # Use pop(key, None) so this never raises KeyError if a future
        # client version stops reporting these wrapper-only parameters.
        params.pop("device", None)
        params.pop("categorical_features_indices", None)
        return params

return TabPFNClassifierWrapper, TabPFNRegressorWrapper

except ImportError:
raise ImportError(
Expand All @@ -76,4 +164,4 @@ def __init__(self, *args, device=None, **kwargs):
)


TabPFNClassifier, TabPFNRegressor, PreprocessorConfig = get_tabpfn_models()
# Resolve the backend once at import time; these names are what the rest of
# the package (e.g. tabpfn_extensions.__init__) imports and re-exports.
TabPFNClassifier, TabPFNRegressor = get_tabpfn_models()

0 comments on commit 44ba312

Please sign in to comment.