Commit
Merge pull request #844 from rasbt/bootstrap
use whole training set in .632 and .632+ bootstrap
rasbt authored Sep 2, 2021
2 parents bf17ad5 + e4e2b4d commit e32aaf2
Showing 4 changed files with 80 additions and 40 deletions.
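In short: each bootstrap round still fits on a resampled training set and scores the out-of-bag samples, but the resubstitution term in the 0.632/0.368 weighting is now computed on the whole training set `X` rather than on the round's bootstrap sample `X[train]`. A minimal, self-contained sketch of the difference (illustrative code, not the library's internals):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier

# One bootstrap round, written out by hand to illustrate the change.
X, y = load_iris(return_X_y=True)
rng = np.random.RandomState(123)
train = rng.randint(0, len(X), size=len(X))    # bootstrap sample (with replacement)
test = np.setdiff1d(np.arange(len(X)), train)  # out-of-bag indices

est = DecisionTreeClassifier(random_state=123).fit(X[train], y[train])
acc_oob = accuracy_score(y[test], est.predict(X[test]))

# Before this commit: resubstitution accuracy on the round's bootstrap sample
acc_resub_old = accuracy_score(y[train], est.predict(X[train]))
# After this commit: resubstitution accuracy on the whole training set
acc_resub_new = accuracy_score(y, est.predict(X))

acc_632 = 0.632 * acc_oob + 0.368 * acc_resub_new
print('OOB: %.3f, .632 estimate: %.3f' % (acc_oob, acc_632))
```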
7 changes: 4 additions & 3 deletions docs/sources/CHANGELOG.md
@@ -27,12 +27,13 @@ The CHANGELOG for the current development version is available at
- Removes deprecated `res` argument from `plot_decision_regions`. ([#803](https://github.com/rasbt/mlxtend/pull/803))
- Adds a `title_fontsize` parameter to `plot_learning_curves` for controlling the title font size; also the plot style is now the matplotlib default. ([#818](https://github.com/rasbt/mlxtend/pull/818))
- Internal change using `'c': 'none'` instead of `'c': ''` in `mlxtend.plotting.plot_decision_regions`'s scatterplot highlights to stay compatible with Matplotlib 3.4 and newer. ([#822](https://github.com/rasbt/mlxtend/pull/822))
- - Adds a `fontcolor_threshold` parameter to the `mlxtend.plotting.plot_confusion_matrix` function as an additional option for determining the font color cut-off manually. ([#825](https://github.com/rasbt/mlxtend/pull/825))
- - The `frequent_patterns.association_rules` now raises a `ValueError` if an empty frequent itemset DataFrame is passed. ([#842](https://github.com/rasbt/mlxtend/pull/842))
+ - Adds a `fontcolor_threshold` parameter to the `mlxtend.plotting.plot_confusion_matrix` function as an additional option for determining the font color cut-off manually. ([#827](https://github.com/rasbt/mlxtend/pull/827))
+ - The `frequent_patterns.association_rules` now raises a `ValueError` if an empty frequent itemset DataFrame is passed. ([#843](https://github.com/rasbt/mlxtend/pull/843))
+ - The .632 and .632+ bootstrap methods implemented in the `mlxtend.evaluate.bootstrap_point632_score` function now use the whole training set for the resubstitution weighting term instead of the internal training set, which is a new bootstrap sample in each round. ([#844](https://github.com/rasbt/mlxtend/pull/844))

##### Bug Fixes

- - Fixes a typo in the SequentialFeatureSelector documentation ([Issue #835](https://github.com/rasbt/mlxtend/issues/835) via [João Pedro Zanlorensi Cardoso](https://github.com/joaozanlorensi))
+ - Fixes a typo in the SequentialFeatureSelector documentation ([#835](https://github.com/rasbt/mlxtend/issues/835) via [João Pedro Zanlorensi Cardoso](https://github.com/joaozanlorensi))


### Version 0.18.0 (11/25/2020)
70 changes: 42 additions & 28 deletions docs/sources/user_guide/evaluate/bootstrap_point632_score.ipynb
@@ -56,7 +56,7 @@
"\n",
"$$\text{ACC}_{boot} = \frac{1}{b} \sum_{i=1}^b \big(0.632 \cdot \text{ACC}_{h, i} + 0.368 \cdot \text{ACC}_{r, i}\big), $$\n",
"\n",
- "where $\text{ACC}_{r, i}$ is the resubstitution accuracy, and $\text{ACC}_{h, i}$ is the accuracy on the out-of-bag sample.\n",
+ "where $\text{ACC}_{train}$ is the accuracy computed on the whole training set (the resubstitution term $\text{ACC}_{r, i}$ in the sum above), and $\text{ACC}_{h, i}$ is the accuracy on the out-of-bag sample.\n",
"\n",
"### .632+ Bootstrap\n",
"\n",
@@ -71,9 +71,9 @@
"\n",
"where *R* is the *relative overfitting rate*\n",
"\n",
- "$$R = \frac{(-1) \times (\text{ACC}_{h, i} - \text{ACC}_{r, i})}{\gamma - (1 -\text{ACC}_{h, i})}.$$\n",
+ "$$R = \frac{(-1) \times (\text{ACC}_{h, i} - \text{ACC}_{train})}{\gamma - (1 - \text{ACC}_{train})}.$$\n",
"\n",
- "(Since we are plugging $\omega$ into the equation for computing $$ACC_{boot}$$ that we defined above, $$\text{ACC}_{h, i}$$ and $\text{ACC}_{r, i}$ still refer to the resubstitution and out-of-bag accuracy estimates in the *i*th bootstrap round, respectively.)\n",
+ "(Since we are plugging $\omega$ into the equation for computing $\text{ACC}_{boot}$ that we defined above, $\text{ACC}_{h, i}$ and $\text{ACC}_{train}$ still refer to the out-of-bag accuracy in the *i*th bootstrap round and the accuracy on the whole training set, respectively.)\n",
"\n",
"Further, we need to determine the *no-information rate* $\\gamma$ in order to compute *R*. For instance, we can compute $\\gamma$ by fitting a model to a dataset that contains all possible combinations between samples $x_{i'}$ and target class labels $y_{i}$ — we pretend that the observations and class labels are independent:\n",
"\n",
@@ -121,8 +121,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Accuracy: 94.36%\n",
- "95% Confidence interval: [88.46, 98.31]\n"
+ "Accuracy: 94.45%\n",
+ "95% Confidence interval: [87.71, 100.00]\n"
]
}
],
@@ -135,7 +135,7 @@
"iris = datasets.load_iris()\n",
"X = iris.data\n",
"y = iris.target\n",
- "tree = DecisionTreeClassifier(random_state=0)\n",
+ "tree = DecisionTreeClassifier(random_state=123)\n",
"\n",
"# Model accuracy\n",
"scores = bootstrap_point632_score(tree, X, y, method='oob')\n",
@@ -165,8 +165,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Accuracy: 96.57%\n",
- "95% Confidence interval: [92.37, 98.95]\n"
+ "Accuracy: 96.42%\n",
+ "95% Confidence interval: [92.41, 100.00]\n"
]
}
],
@@ -179,7 +179,7 @@
"iris = datasets.load_iris()\n",
"X = iris.data\n",
"y = iris.target\n",
- "tree = DecisionTreeClassifier(random_state=0)\n",
+ "tree = DecisionTreeClassifier(random_state=123)\n",
"\n",
"# Model accuracy\n",
"scores = bootstrap_point632_score(tree, X, y)\n",
@@ -209,8 +209,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Accuracy: 96.28%\n",
- "95% Confidence interval: [92.10, 98.90]\n"
+ "Accuracy: 96.29%\n",
+ "95% Confidence interval: [91.86, 98.92]\n"
]
}
],
@@ -223,7 +223,7 @@
"iris = datasets.load_iris()\n",
"X = iris.data\n",
"y = iris.target\n",
- "tree = DecisionTreeClassifier(random_state=0)\n",
+ "tree = DecisionTreeClassifier(random_state=123)\n",
"\n",
"# Model accuracy\n",
"scores = bootstrap_point632_score(tree, X, y, method='.632+')\n",
@@ -255,21 +255,21 @@
"text": [
"## bootstrap_point632_score\n",
"\n",
- "*bootstrap_point632_score(estimator, X, y, n_splits=200, method='.632', scoring_func=None, random_seed=None, clone_estimator=True)*\n",
+ "*bootstrap_point632_score(estimator, X, y, n_splits=200, method='.632', scoring_func=None, predict_proba=False, random_seed=None, clone_estimator=True)*\n",
"\n",
"Implementation of the .632 [1] and .632+ [2] bootstrap\n",
- "for supervised learning\n",
+ " for supervised learning\n",
"\n",
- "References:\n",
+ " References:\n",
"\n",
- "- [1] Efron, Bradley. 1983. Estimating the Error Rate\n",
- "of a Prediction Rule: Improvement on Cross-Validation.\n",
- "Journal of the American Statistical Association\n",
- "78 (382): 316. doi:10.2307/2288636.\n",
- "- [2] Efron, Bradley, and Robert Tibshirani. 1997.\n",
- "Improvements on Cross-Validation: The .632+ Bootstrap Method.\n",
- "Journal of the American Statistical Association\n",
- "92 (438): 548. doi:10.2307/2965703.\n",
+ " - [1] Efron, Bradley. 1983. \"Estimating the Error Rate\n",
+ " of a Prediction Rule: Improvement on Cross-Validation.\"\n",
+ " Journal of the American Statistical Association\n",
+ " 78 (382): 316. doi:10.2307/2288636.\n",
+ " - [2] Efron, Bradley, and Robert Tibshirani. 1997.\n",
+ " \"Improvements on Cross-Validation: The .632+ Bootstrap Method.\"\n",
+ " Journal of the American Statistical Association\n",
+ " 92 (438): 548. doi:10.2307/2965703.\n",
"\n",
"**Parameters**\n",
"\n",
@@ -316,6 +316,19 @@
" if the estimator is a regressor.\n",
"\n",
"\n",
+ "- `predict_proba` : bool\n",
+ "\n",
+ " Whether to use the `predict_proba` function for the\n",
+ " `estimator` argument. This is to be used in conjunction\n",
+ " with `scoring_func`, which takes in probability values\n",
+ " instead of actual predictions.\n",
+ " For example, if the scoring_func is\n",
+ " :meth:`sklearn.metrics.roc_auc_score`, then use\n",
+ " `predict_proba=True`.\n",
+ " Note that this requires `estimator` to have\n",
+ " a `predict_proba` method implemented.\n",
+ "\n",
+ "\n",
"- `random_seed` : int (default=None)\n",
"\n",
" If int, random_seed is the seed used by\n",
@@ -336,7 +349,7 @@
"\n",
"**Examples**\n",
"\n",
- "\n",
+ "```\n",
" >>> from sklearn import datasets, linear_model\n",
" >>> from mlxtend.evaluate import bootstrap_point632_score\n",
" >>> iris = datasets.load_iris()\n",
@@ -352,8 +365,9 @@
" >>> print('95%% Confidence interval: [%.2f, %.2f]' % (lower, upper))\n",
" 95% Confidence interval: [0.90, 0.98]\n",
"\n",
- "For more usage examples, please see\n",
- "[http://rasbt.github.io/mlxtend/user_guide/evaluate/bootstrap_point632_score/](http://rasbt.github.io/mlxtend/user_guide/evaluate/bootstrap_point632_score/)\n",
+ " For more usage examples, please see\n",
+ " http://rasbt.github.io/mlxtend/user_guide/evaluate/bootstrap_point632_score/\n",
+ "```\n",
"\n",
"\n"
]
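The `predict_proba` option documented above pairs with probability-based metrics. A hedged usage sketch (hypothetical binary-classification setup; for binary problems the implementation passes the last probability column to `scoring_func`, per the code changes below):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

from mlxtend.evaluate import bootstrap_point632_score

# Hypothetical example: score with ROC AUC, which needs probabilities
# rather than hard class predictions.
X, y = make_classification(n_samples=300, random_state=123)
clf = LogisticRegression(max_iter=1000)

scores = bootstrap_point632_score(clf, X, y,
                                  scoring_func=roc_auc_score,
                                  predict_proba=True,
                                  random_seed=123)
print('AUC: %.2f' % np.mean(scores))
```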
@@ -369,7 +383,7 @@
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -383,7 +397,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.7.1"
+ "version": "3.9.6"
},
"toc": {
"nav_menu": {},
31 changes: 28 additions & 3 deletions mlxtend/evaluate/bootstrap_point632.py
@@ -185,13 +185,32 @@ def bootstrap_point632_score(estimator, X, y, n_splits=200,
oob = BootstrapOutOfBag(n_splits=n_splits, random_seed=random_seed)
scores = np.empty(dtype=np.float, shape=(n_splits,))
cnt = 0

for train, test in oob.split(X):
cloned_est.fit(X[train], y[train])

# get the prediction probability
# for binary class uses the last column
predicted_test_val = predict_func(X[test])
- predicted_train_val = predict_func(X[train])

+ if method in ('.632', '.632+'):
+     # predictions on the internal training set:
+     # predicted_train_val = predict_func(X[train])
+
+     # compute the training error on the whole training set, as given
+     # in Eq. (6.12) of the original .632 bootstrap paper,
+     # "Estimating the Error Rate of a Prediction Rule: Improvement
+     # on Cross-Validation" by B. Efron, 1983,
+     # https://doi.org/10.2307/2288636
+     # Also see the discussion at
+     # https://github.com/rasbt/mlxtend/discussions/828
+     #
+     # This also applies to the .632+ estimate in the paper
+     # "Improvements on Cross-Validation: The .632+ Bootstrap Method"
+     # https://www.tandfonline.com/doi/abs/10.1080/01621459.1997.10474007
+     predicted_train_val = predict_func(X)

if predict_proba:
len_uniq = np.unique(y)

@@ -206,13 +225,19 @@

else:
test_err = 1 - test_acc
- train_err = 1 - scoring_func(y[train], predicted_train_val)

+ # training error on the whole training set, as mentioned in the
+ # comment above
+ train_err = 1 - scoring_func(
+     y, predicted_train_val)

if method == '.632+':
gamma = 1 - (no_information_rate(
y,
cloned_est.predict(X),
scoring_func))
- R = (test_err - train_err) / (gamma - train_err)
+ R = (test_err - train_err) / (
+     gamma - train_err)
weight = 0.632 / (1 - 0.368*R)

else:
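For context, the `BootstrapOutOfBag` splitter driving the loop above yields, per round, a bootstrap sample of indices (drawn with replacement) plus the complementary out-of-bag indices. A small sketch of its behavior:

```python
import numpy as np
from mlxtend.evaluate import BootstrapOutOfBag

X = np.arange(10).reshape(-1, 1)
oob = BootstrapOutOfBag(n_splits=3, random_seed=123)
for train, test in oob.split(X):
    # the bootstrap sample has len(X) indices, possibly repeated;
    # the out-of-bag indices never overlap with it
    assert len(train) == len(X)
    assert not set(train) & set(test)
    print(sorted(set(train))[:5], sorted(test)[:5])
```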
12 changes: 6 additions & 6 deletions mlxtend/evaluate/tests/test_bootstrap_point632.py
@@ -40,7 +40,7 @@ def test_defaults():
scores = bootstrap_point632_score(lr, X, y, random_seed=123)
acc = np.mean(scores)
assert len(scores == 200)
- assert np.round(acc, 5) == 0.95306, np.round(acc, 5)
+ assert np.round(acc, 5) == 0.95117, np.round(acc, 5)


def test_oob():
@@ -58,14 +58,14 @@ def test_632():
method='.632')
acc = np.mean(scores)
assert len(scores == 200)
- assert np.round(acc, 5) == 0.96629, np.round(acc, 5)
+ assert np.round(acc, 5) == 0.95914, np.round(acc, 5)

tree2 = DecisionTreeClassifier(random_state=123, max_depth=1)
scores = bootstrap_point632_score(tree2, X, y, random_seed=123,
method='.632')
acc = np.mean(scores)
assert len(scores == 200)
- assert np.round(acc, 5) == 0.65512, np.round(acc, 5)
+ assert np.round(acc, 5) == 0.64355, np.round(acc, 5)


def test_632plus():
@@ -74,14 +74,14 @@ def test_632plus():
method='.632+')
acc = np.mean(scores)
assert len(scores == 200)
- assert np.round(acc, 5) == 0.9649, np.round(acc, 5)
+ assert np.round(acc, 5) == 0.95855, np.round(acc, 5)

tree2 = DecisionTreeClassifier(random_state=123, max_depth=1)
scores = bootstrap_point632_score(tree2, X, y, random_seed=123,
method='.632+')
acc = np.mean(scores)
assert len(scores == 200)
- assert np.round(acc, 5) == 0.64831, np.round(acc, 5)
+ assert np.round(acc, 5) == 0.64078, np.round(acc, 5)


def test_custom_accuracy():
@@ -95,7 +95,7 @@ def accuracy2(targets, predictions):
scoring_func=accuracy2)
acc = np.mean(scores)
assert len(scores == 200)
- assert np.round(acc, 5) == 0.95306, np.round(acc, 5)
+ assert np.round(acc, 5) == 0.95117, np.round(acc, 5)


def test_invalid_splits():
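The expected values in these tests shift slightly because the resubstitution term is now computed on the whole training set. A hedged re-check of the `test_632()` expectation (the test's fixtures are elided from this diff; this sketch assumes the Iris data and the unpruned `tree` estimator shown above):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from mlxtend.data import iris_data
from mlxtend.evaluate import bootstrap_point632_score

X, y = iris_data()  # assumption: the test suite's X, y
tree = DecisionTreeClassifier(random_state=123)

scores = bootstrap_point632_score(tree, X, y, random_seed=123, method='.632')
print(np.round(np.mean(scores), 5))  # compare against the asserted 0.95914
```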
