Skip to content

Commit

Permalink
Remove distutils, drop support for old sklearn
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnoe committed Jan 10, 2024
1 parent dfd6e7b commit 4606b8c
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 50 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4
Expand Down
70 changes: 22 additions & 48 deletions funfolding/binning/tree_sklearn_based.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
import copy
from distutils.version import StrictVersion

import sklearn
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
Expand All @@ -10,8 +9,6 @@
import warnings


old_sklean = StrictVersion("0.19.0") >= StrictVersion(sklearn.__version__)


def __sample_uniform__(y, sample_weight=None, random_state=None):
"""Function used to sample a uniform distribution from a binned y.
Expand Down Expand Up @@ -298,28 +295,16 @@ def __init__(self,
self.uniform = uniform

if regression:
if old_sklean:
if min_impurity_decrease != 0.:
warnings.warn('min_impurity_decrease is supported '
' only in sklearn version >= 0.19.0')
self.tree = DecisionTreeRegressor(
max_depth=max_depth,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
max_leaf_nodes=max_leaf_nodes,
max_features=max_features,
min_weight_fraction_leaf=min_weight_fraction_leaf,
random_state=random_state)
else:
self.tree = DecisionTreeRegressor(
max_depth=max_depth,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
max_leaf_nodes=max_leaf_nodes,
max_features=max_features,
min_impurity_decrease=min_impurity_decrease,
min_weight_fraction_leaf=min_weight_fraction_leaf,
random_state=random_state)
self.tree = DecisionTreeRegressor(
max_depth=max_depth,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
max_leaf_nodes=max_leaf_nodes,
max_features=max_features,
min_impurity_decrease=min_impurity_decrease,
min_weight_fraction_leaf=min_weight_fraction_leaf,
random_state=random_state,
)
if boosted in ['linear', 'square', 'exponential']:
self.boosted = AdaBoostRegressor(
base_estimator=self.tree,
Expand All @@ -336,35 +321,24 @@ def __init__(self,
else:
self.boosted = None
else:
if old_sklean:
if min_impurity_decrease != 0.:
warnings.warn('min_impurity_decrease is supported '
' only in sklearn version >= 0.19.0')
self.tree = DecisionTreeClassifier(
max_depth=max_depth,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
max_leaf_nodes=max_leaf_nodes,
max_features=max_features,
min_weight_fraction_leaf=min_weight_fraction_leaf,
random_state=random_state)
else:
self.tree = DecisionTreeClassifier(
max_depth=max_depth,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
max_leaf_nodes=max_leaf_nodes,
max_features=max_features,
min_weight_fraction_leaf=min_weight_fraction_leaf,
min_impurity_decrease=min_impurity_decrease,
random_state=random_state)
self.tree = DecisionTreeClassifier(
max_depth=max_depth,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
max_leaf_nodes=max_leaf_nodes,
max_features=max_features,
min_weight_fraction_leaf=min_weight_fraction_leaf,
min_impurity_decrease=min_impurity_decrease,
random_state=random_state,
)
if boosted in ['SAMME', 'SAMME.R']:
self.boosted = AdaBoostClassifier(
base_estimator=self.tree,
n_estimators=n_estimators,
learning_rate=learning_rate,
algorithm=boosted,
random_state=random_state)
random_state=random_state,
)
elif boosted is not None:
raise ValueError(
'\'boosted\' should be None for no boosting '
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
'matplotlib',
'numpy',
'pymc3',
'scikit-learn>=0.18.1',
'scikit-learn>=0.19.0',
'scipy',
'six>=1.1',
],
Expand Down

0 comments on commit 4606b8c

Please sign in to comment.