Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix phe and hpo with TabPFN API #9

Merged
merged 5 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/tabpfn_extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
__version__ = "0.1.0.dev0"

from .utils import is_tabpfn
from .utils import TabPFNRegressor, TabPFNClassifier, PreprocessorConfig
from .utils import TabPFNRegressor, TabPFNClassifier

from . import utils_todo
15 changes: 7 additions & 8 deletions src/tabpfn_extensions/hpo/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from hyperopt import hp
from pathlib import Path
from tabpfn_extensions import PreprocessorConfig


def enumerate_preprocess_transforms():
Expand Down Expand Up @@ -47,13 +46,13 @@ def enumerate_preprocess_transforms():
for global_transformer_name in [None, "svd"]:
transforms += [
[
PreprocessorConfig(
name=name,
global_transformer_name=global_transformer_name,
subsample_features=subsample_features,
categorical_name=categorical_name,
append_original=append_original,
)
{
"name": name,
"global_transformer_name": global_transformer_name,
"subsample_features": subsample_features,
"categorical_name": categorical_name,
"append_original": append_original,
}
for name in names
],
]
Expand Down
128 changes: 108 additions & 20 deletions src/tabpfn_extensions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
import os

import os
from typing import Any, Type, Tuple, Protocol

from typing import Any, Type, Tuple, Protocol, Literal, Optional, Union, Dict
from dataclasses import dataclass
from typing_extensions import override
from sklearn.base import BaseEstimator
import numpy as np
import warnings

class TabPFNEstimator(Protocol):
def fit(self, X: Any, y: Any) -> Any:
Expand All @@ -29,7 +33,6 @@ def is_tabpfn(estimator: Any) -> bool:
except (AttributeError, TypeError):
return False


from typing import Tuple, Type
import os
USE_TABPFN_LOCAL = os.getenv("USE_TABPFN_LOCAL", "true").lower() == "true"
Expand All @@ -39,9 +42,8 @@ def get_tabpfn_models() -> Tuple[Type, Type, Type]:
if USE_TABPFN_LOCAL:
try:
from tabpfn import TabPFNClassifier, TabPFNRegressor
from tabpfn.preprocessing import PreprocessorConfig

return TabPFNClassifier, TabPFNRegressor, PreprocessorConfig
return TabPFNClassifier, TabPFNRegressor
except ImportError:
pass

Expand All @@ -50,22 +52,108 @@ def get_tabpfn_models() -> Tuple[Type, Type, Type]:
TabPFNClassifier as ClientTabPFNClassifier,
TabPFNRegressor as ClientTabPFNRegressor,
)
from tabpfn_client.estimator import (
PreprocessorConfig as ClientPreprocessorConfig,
)

# Wrapper classes to add device parameter
class TabPFNClassifier(ClientTabPFNClassifier):
def __init__(self, *args, device=None, **kwargs):
super().__init__(*args, **kwargs)
# Ignoring the device parameter for now

class TabPFNRegressor(ClientTabPFNRegressor):
def __init__(self, *args, device=None, **kwargs):
super().__init__(*args, **kwargs)
# Ignoring the device parameter for now

return TabPFNClassifier, TabPFNRegressor, ClientPreprocessorConfig
# Wrapper classes to add device parameter
# we can't use *args because scikit-learn needs to know the parameters of the constructor
class TabPFNClassifierWrapper(ClientTabPFNClassifier):
    """Client-side TabPFN classifier that also accepts a ``device`` argument.

    The local and client TabPFN implementations are swapped in behind one
    name, so this wrapper mirrors the local constructor signature.  ``device``
    and ``categorical_features_indices`` are accepted for compatibility but
    have no effect on the remote client (a warning is emitted for the latter).

    All parameters are spelled out explicitly — no ``*args``/``**kwargs`` —
    because scikit-learn introspects ``__init__`` to discover estimator
    parameters for ``get_params``/``clone``.
    """

    def __init__(
        self,
        device: Union[str, None] = None,
        categorical_features_indices: Optional[list[int]] = None,
        model_path: str = "default",
        n_estimators: int = 4,
        softmax_temperature: float = 0.9,
        balance_probabilities: bool = False,
        average_before_softmax: bool = False,
        ignore_pretraining_limits: bool = False,
        inference_precision: Literal["autocast", "auto"] = "auto",
        random_state: Optional[
            Union[int, np.random.RandomState, np.random.Generator]
        ] = None,
        inference_config: Optional[Dict] = None,
        paper_version: bool = False,
    ) -> None:
        # Inference runs remotely, so device selection is meaningless here;
        # the attribute is kept only for signature compatibility.
        self.device = device
        # TODO: we should support this argument in the client version
        self.categorical_features_indices = categorical_features_indices
        if categorical_features_indices is not None:
            warnings.warn(
                "categorical_features_indices is not supported in the client version of TabPFN and will be ignored",
                UserWarning,
                stacklevel=2,
            )
        # Map a local checkpoint path (e.g. ".../tabpfn-v2-classifier.ckpt")
        # onto the short model name the client API expects; the generic
        # "classifier" checkpoint maps to the client's "default" model.
        if "/" in model_path:
            model_name = model_path.split("/")[-1].split("-")[-1].split(".")[0]
            if model_name == "classifier":
                model_name = "default"
            self.model_path = model_name
        else:
            self.model_path = model_path

        super().__init__(
            model_path=self.model_path,
            n_estimators=n_estimators,
            softmax_temperature=softmax_temperature,
            balance_probabilities=balance_probabilities,
            average_before_softmax=average_before_softmax,
            ignore_pretraining_limits=ignore_pretraining_limits,
            inference_precision=inference_precision,
            random_state=random_state,
            inference_config=inference_config,
            paper_version=paper_version,
        )

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        """Return estimator parameters without the wrapper-only arguments.

        ``device`` and ``categorical_features_indices`` are not real client
        parameters, so they must not round-trip through ``set_params``/
        ``clone``.  ``pop`` is given a default because the parent's
        ``get_params`` implementation may not report these attributes at all,
        and a missing key must not raise ``KeyError``.
        """
        params = super().get_params(deep=deep)
        params.pop("device", None)
        params.pop("categorical_features_indices", None)
        return params

class TabPFNRegressorWrapper(ClientTabPFNRegressor):
    """Client-side TabPFN regressor that also accepts a ``device`` argument.

    Mirrors the local TabPFN regressor signature so the two implementations
    are interchangeable.  ``device`` and ``categorical_features_indices`` are
    accepted for compatibility but have no effect on the remote client (a
    warning is emitted for the latter).

    Parameters are listed explicitly — no ``*args``/``**kwargs`` — because
    scikit-learn introspects ``__init__`` for ``get_params``/``clone``.
    """

    def __init__(
        self,
        device: Union[str, None] = None,
        categorical_features_indices: Optional[list[int]] = None,
        model_path: str = "default",
        n_estimators: int = 8,
        softmax_temperature: float = 0.9,
        average_before_softmax: bool = False,
        ignore_pretraining_limits: bool = False,
        inference_precision: Literal["autocast", "auto"] = "auto",
        random_state: Optional[
            Union[int, np.random.RandomState, np.random.Generator]
        ] = None,
        inference_config: Optional[Dict] = None,
        paper_version: bool = False,
    ) -> None:
        # Inference runs remotely, so device selection is meaningless here;
        # the attribute is kept only for signature compatibility.
        self.device = device
        self.categorical_features_indices = categorical_features_indices
        if categorical_features_indices is not None:
            warnings.warn(
                "categorical_features_indices is not supported in the client version of TabPFN and will be ignored",
                UserWarning,
                stacklevel=2,
            )
        # NOTE(review): unlike the classifier wrapper, no checkpoint-path ->
        # model-name mapping is applied to model_path here — confirm whether
        # that asymmetry is intentional.
        super().__init__(
            model_path=model_path,
            n_estimators=n_estimators,
            softmax_temperature=softmax_temperature,
            average_before_softmax=average_before_softmax,
            ignore_pretraining_limits=ignore_pretraining_limits,
            inference_precision=inference_precision,
            random_state=random_state,
            inference_config=inference_config,
            paper_version=paper_version,
        )

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        """Return estimator parameters without the wrapper-only arguments.

        ``device`` and ``categorical_features_indices`` are not real client
        parameters, so they must not round-trip through ``set_params``/
        ``clone``.  ``pop`` is given a default because the parent's
        ``get_params`` implementation may not report these attributes at all,
        and a missing key must not raise ``KeyError``.
        """
        params = super().get_params(deep=deep)
        params.pop("device", None)
        params.pop("categorical_features_indices", None)
        return params

return TabPFNClassifierWrapper, TabPFNRegressorWrapper

except ImportError:
raise ImportError(
Expand All @@ -76,4 +164,4 @@ def __init__(self, *args, device=None, **kwargs):
)


TabPFNClassifier, TabPFNRegressor, PreprocessorConfig = get_tabpfn_models()
TabPFNClassifier, TabPFNRegressor = get_tabpfn_models()