|
1 | 1 | """ |
2 | | -========================================================= |
3 | | -OOB Errors for Random Forests and Extra Trees Classifiers |
4 | | -========================================================= |
| 2 | +============================= |
| 3 | +OOB Errors for Random Forests |
| 4 | +============================= |
5 | 5 |
|
6 | | -The ``RandomForestClassifier`` and ``ExtraTreesClasifier`` are trained using |
7 | | -*bootstrap aggregation*. During training, each new tree is fit from a |
8 | | -bootstrap sample of the training observations :math:`z_i = (x_i, y_i)`. The |
9 | | -*out-of-bag* (OOB) error is the average prediction error for each :math:`z_i` |
10 | | -from trees that do not contain :math:`z_i` in their respective bootstrap |
11 | | -sample. This allows models to be simultaneously fit and cross-validated [1]. |
| 6 | +The ``RandomForestClassifier`` is trained using *bootstrap aggregation*. During |
| 7 | +training, each new tree is fit from a bootstrap sample of the training |
| 8 | +observations :math:`z_i = (x_i, y_i)`. The *out-of-bag* (OOB) error is the |
| 9 | +average prediction error for each :math:`z_i` from trees that do not contain |
| 10 | +:math:`z_i` in their respective bootstrap sample. This allows models to be |
| 11 | +simultaneously fit and validated [1]. |
12 | 12 |
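The OOB estimate is exposed directly on the fitted estimator: with ``oob_score=True``, scikit-learn stores the mean OOB accuracy in the ``oob_score_`` attribute, so the OOB error is simply its complement. A minimal sketch on a toy dataset (parameter values chosen arbitrarily for illustration):

    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier

    X, y = make_classification(n_samples=500, random_state=123)
    clf = RandomForestClassifier(oob_score=True, random_state=123)
    clf.fit(X, y)
    # oob_score_ is the mean out-of-bag accuracy; the OOB error is 1 - accuracy.
    print(1 - clf.oob_score_)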
|
13 | 13 | The example below demonstrates how the OOB error can be measured at the |
14 | | -inclusion of each new tree whilst fitting ``RandomForestClassifier`` and |
15 | | -``ExtraTreesClassifier`` models. The subsequent plot enables the practitioner |
16 | | -to approximate the error stabilization point of each model at which training |
17 | | -can be halted. |
| 14 | +inclusion of each new tree whilst fitting ``RandomForestClassifier`` models. |
| 15 | +The subsequent plot enables the practitioner to approximate the error |
| 16 | +stabilization point of each model at which training can be halted. |
18 | 17 |
|
19 | 18 | .. [1] T. Hastie, R. Tibshirani and J. Friedman, "Elements of Statistical |
20 | | - Learning Ed. 2", Springer, 2009. |
| 19 | + Learning Ed. 2", p592-593, Springer, 2009. |
21 | 20 |
|
22 | 21 | """ |
23 | 22 | import matplotlib.pyplot as plt |
|
37 | 36 | RANDOM_STATE = 123 |
38 | 37 |
|
39 | 38 | # Generate a binary classification dataset. |
40 | | -X, y = make_classification(n_samples=500, n_features=30, |
41 | | - n_clusters_per_class=1, |
| 39 | +X, y = make_classification(n_samples=500, n_features=25, |
| 40 | + n_clusters_per_class=1, n_informative=15, |
42 | 41 | random_state=RANDOM_STATE) |
43 | 42 |
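For orientation, the call above should yield a 500 x 25 design matrix with a binary target, in which 15 columns carry class signal and the remainder are redundant or noise features. A quick check (assuming ``X`` and ``y`` from the snippet above, with numpy imported separately):

    import numpy as np

    print(X.shape)       # (500, 25)
    print(np.unique(y))  # [0 1]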
|
44 | 43 | # NOTE: Setting the `warm_start` construction parameter to `True` disables |
45 | 44 | # support for parallelised ensembles but is necessary for tracking the OOB
46 | 45 | # error trajectory during training. |
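As a rough sketch of the pattern this note refers to: with ``warm_start=True``, successive calls to ``fit`` keep the trees already grown and only add new ones, so the same estimator can be enlarged step by step via ``set_params`` while its OOB score is re-read after each step:

    clf = RandomForestClassifier(warm_start=True, oob_score=True,
                                 random_state=RANDOM_STATE)
    for n in (15, 30, 45):
        clf.set_params(n_estimators=n)  # only the additional trees are fit
        clf.fit(X, y)
        print(n, 1 - clf.oob_score_)    # OOB error after n trees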
47 | 46 | ensemble_clfs = [ |
48 | | - ("RandomForestClassifier, max_features='auto'", |
| 47 | + ("RandomForestClassifier, max_features='sqrt'", |
49 | 48 | RandomForestClassifier(warm_start=True, oob_score=True, |
50 | | - max_features="auto", |
| 49 | + max_features="sqrt", |
51 | 50 | random_state=RANDOM_STATE)), |
52 | | - ("RandomForestClassifier, max_features=2", |
53 | | - RandomForestClassifier(warm_start=True, max_features=2, |
| 51 | + ("RandomForestClassifier, max_features='log2'", |
| 52 | + RandomForestClassifier(warm_start=True, max_features='log2', |
54 | 53 | oob_score=True, |
55 | 54 | random_state=RANDOM_STATE)), |
56 | | - ("ExtraTreesClassifier, max_features='auto'", |
57 | | - ExtraTreesClassifier(warm_start=True, max_features="auto", |
58 | | - oob_score=True, bootstrap=True, |
59 | | - random_state=RANDOM_STATE)), |
60 | | - ("ExtraTreesClassifier, max_features=2", |
61 | | - ExtraTreesClassifier(warm_start=True, max_features=2, |
62 | | - oob_score=True, bootstrap=True, |
63 | | - random_state=RANDOM_STATE)) |
| 55 | + ("RandomForestClassifier, max_features=None", |
| 56 | + RandomForestClassifier(warm_start=True, max_features=None, |
| 57 | + oob_score=True, |
| 58 | + random_state=RANDOM_STATE)) |
64 | 59 | ] |
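For the 25-feature dataset used here, the three settings differ in how many candidate features each split considers. Assuming scikit-learn's usual truncation to an integer, the counts work out as:

    import math

    n_features = 25
    print(max(1, int(math.sqrt(n_features))))   # 'sqrt' -> 5
    print(max(1, int(math.log2(n_features))))   # 'log2' -> 4
    print(n_features)                           # None   -> all 25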
65 | 60 |
|
66 | 61 | # Map a classifier name to a list of (<n_estimators>, <error rate>) pairs. |
67 | 62 | error_rate = OrderedDict((label, []) for label, _ in ensemble_clfs) |
68 | 63 |
|
69 | 64 | # Range of `n_estimators` values to explore. |
70 | 65 | min_estimators = 15 |
71 | | -max_estimators = 150 |
| 66 | +max_estimators = 175 |
72 | 67 |
|
73 | 68 | for label, clf in ensemble_clfs: |
74 | 69 | for i in range(min_estimators, max_estimators + 1): |
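The diff is truncated here, but the loop body presumably follows the warm-start pattern sketched above: grow the current classifier to ``i`` trees, refit, and record the OOB error for later plotting. A sketch of that step:

    clf.set_params(n_estimators=i)
    clf.fit(X, y)
    # Store the (ensemble size, OOB error) pair for this classifier.
    error_rate[label].append((i, 1 - clf.oob_score_))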
|