
Commit e038cb3

Author: Kian Ho
Made more amendments according to PR feedback.
A number of further amendments to the plot_ensemble_oob.py example script were suggested in the PR thread and addressed accordingly:

- The ExtraTreesClassifier models were removed from the example, since they don't use bootstrapping by default (though they can when constructed with bootstrap=True; see the sketch below).
- Included the OOB errors for RandomForestClassifier models with various max_features values.
- Changed the sample dataset to make for a nicer-looking plot.
- Changed "cross-validated" to "validated" in the docstring.
- Added the relevant page numbers to the Hastie et al. reference.
- PEP8 compliance: fixed a line > 80 chars.
Parent: b6ce322
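On the first bullet above: a minimal sketch (not part of this commit) of how an ExtraTreesClassifier would need to be constructed for an OOB estimate to be available, given that extra-trees do not bootstrap by default:

from sklearn.ensemble import ExtraTreesClassifier

# Extra-trees default to bootstrap=False (each tree sees the whole training
# set), so oob_score requires explicitly enabling bootstrapping.
clf = ExtraTreesClassifier(bootstrap=True, oob_score=True,
                           random_state=123)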

File tree

1 file changed: +24 additions, -29 deletions

examples/ensemble/plot_ensemble_oob.py

@@ -1,23 +1,22 @@
 """
-=========================================================
-OOB Errors for Random Forests and Extra Trees Classifiers
-=========================================================
+=============================
+OOB Errors for Random Forests
+=============================
 
-The ``RandomForestClassifier`` and ``ExtraTreesClasifier`` are trained using
-*bootstrap aggregation*. During training, each new tree is fit from a
-bootstrap sample of the training observations :math:`z_i = (x_i, y_i)`. The
-*out-of-bag* (OOB) error is the average prediction error for each :math:`z_i`
-from trees that do not contain :math:`z_i` in their respective bootstrap
-sample. This allows models to be simultaneously fit and cross-validated [1].
+The ``RandomForestClassifier`` is trained using *bootstrap aggregation*. During
+training, each new tree is fit from a bootstrap sample of the training
+observations :math:`z_i = (x_i, y_i)`. The *out-of-bag* (OOB) error is the
+average prediction error for each :math:`z_i` from trees that do not contain
+:math:`z_i` in their respective bootstrap sample. This allows models to be
+simultaneously fit and validated [1].
 
 The example below demonstrates how the OOB error can be measured at the
-inclusion of each new tree whilst fitting ``RandomForestClassifier`` and
-``ExtraTreesClassifier`` models. The subsequent plot enables the practitioner
-to approximate the error stabilization point of each model at which training
-can be halted.
+inclusion of each new tree whilst fitting ``RandomForestClassifier`` models.
+The subsequent plot enables the practitioner to approximate the error
+stabilization point of each model at which training can be halted.
 
 .. [1] T. Hastie, R. Tibshirani and J. Friedman, "Elements of Statistical
-       Learning Ed. 2", Springer, 2009.
+       Learning Ed. 2", p592-593, Springer, 2009.
 
 """
 import matplotlib.pyplot as plt
@@ -37,38 +36,34 @@
 RANDOM_STATE = 123
 
 # Generate a binary classification dataset.
-X, y = make_classification(n_samples=500, n_features=30,
-                           n_clusters_per_class=1,
+X, y = make_classification(n_samples=500, n_features=25,
+                           n_clusters_per_class=1, n_informative=15,
                            random_state=RANDOM_STATE)
 
 # NOTE: Setting the `warm_start` construction parameter to `True` disables
 # support for paralellised ensembles but is necessary for tracking the OOB
 # error trajectory during training.
 ensemble_clfs = [
-    ("RandomForestClassifier, max_features='auto'",
+    ("RandomForestClassifier, max_features='sqrt'",
         RandomForestClassifier(warm_start=True, oob_score=True,
-                               max_features="auto",
+                               max_features="sqrt",
                                random_state=RANDOM_STATE)),
-    ("RandomForestClassifier, max_features=2",
-        RandomForestClassifier(warm_start=True, max_features=2,
+    ("RandomForestClassifier, max_features='log2'",
+        RandomForestClassifier(warm_start=True, max_features='log2',
                                oob_score=True,
                                random_state=RANDOM_STATE)),
-    ("ExtraTreesClassifier, max_features='auto'",
-        ExtraTreesClassifier(warm_start=True, max_features="auto",
-                             oob_score=True, bootstrap=True,
-                             random_state=RANDOM_STATE)),
-    ("ExtraTreesClassifier, max_features=2",
-        ExtraTreesClassifier(warm_start=True, max_features=2,
-                             oob_score=True, bootstrap=True,
-                             random_state=RANDOM_STATE))
+    ("RandomForestClassifier, max_features=None",
+        RandomForestClassifier(warm_start=True, max_features=None,
+                               oob_score=True,
+                               random_state=RANDOM_STATE))
 ]
 
 # Map a classifier name to a list of (<n_estimators>, <error rate>) pairs.
 error_rate = OrderedDict((label, []) for label, _ in ensemble_clfs)
 
 # Range of `n_estimators` values to explore.
 min_estimators = 15
-max_estimators = 150
+max_estimators = 175
 
 for label, clf in ensemble_clfs:
     for i in range(min_estimators, max_estimators + 1):
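The hunk ends at the head of the training loop, which this commit leaves unchanged. For context, the loop body proceeds roughly as follows, a sketch reconstructed from the `warm_start` note and the `error_rate` mapping above rather than text taken from the diff:

for label, clf in ensemble_clfs:
    for i in range(min_estimators, max_estimators + 1):
        # With warm_start=True, raising n_estimators and refitting grows
        # the existing ensemble rather than retraining from scratch.
        clf.set_params(n_estimators=i)
        clf.fit(X, y)

        # oob_score_ is the accuracy on out-of-bag samples; record the
        # complementary error rate for this ensemble size.
        oob_error = 1 - clf.oob_score_
        error_rate[label].append((i, oob_error))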
