
Commit 524daee

Merge branch 'bagging'
Fixes scikit-learn#2375.
2 parents c760a2f + 9a90e13

File tree

11 files changed: +1638 -112 lines


doc/modules/classes.rst

Lines changed: 7 additions & 5 deletions
@@ -298,15 +298,17 @@ Samples generator
    :toctree: generated/
    :template: class.rst
 
-   ensemble.RandomForestClassifier
-   ensemble.RandomTreesEmbedding
-   ensemble.RandomForestRegressor
-   ensemble.ExtraTreesClassifier
-   ensemble.ExtraTreesRegressor
    ensemble.AdaBoostClassifier
    ensemble.AdaBoostRegressor
+   ensemble.BaggingClassifier
+   ensemble.BaggingRegressor
+   ensemble.ExtraTreesClassifier
+   ensemble.ExtraTreesRegressor
    ensemble.GradientBoostingClassifier
    ensemble.GradientBoostingRegressor
+   ensemble.RandomForestClassifier
+   ensemble.RandomTreesEmbedding
+   ensemble.RandomForestRegressor
 
 .. autosummary::
    :toctree: generated/

doc/modules/ensemble.rst

Lines changed: 77 additions & 9 deletions
@@ -7,25 +7,93 @@ Ensemble methods
 .. currentmodule:: sklearn.ensemble
 
 The goal of **ensemble methods** is to combine the predictions of several
-models built with a given learning algorithm in order to improve
-generalizability / robustness over a single model.
+base estimators built with a given learning algorithm in order to improve
+generalizability / robustness over a single estimator.
 
 Two families of ensemble methods are usually distinguished:
 
 - In **averaging methods**, the driving principle is to build several
-  models independently and then to average their predictions. On average,
-  the combined model is usually better than any of the single model
-  because its variance is reduced.
+  estimators independently and then to average their predictions. On average,
+  the combined estimator is usually better than any single base estimator
+  because its variance is reduced.
 
-  **Examples:** Bagging methods, :ref:`Forests of randomized trees <forest>`...
+  **Examples:** :ref:`Bagging methods <bagging>`, :ref:`Forests of randomized trees <forest>`, ...
 
-- By contrast, in **boosting methods**, models are built sequentially and one
-  tries to reduce the bias of the combined model. The motivation is to combine
-  several weak models to produce a powerful ensemble.
+- By contrast, in **boosting methods**, base estimators are built sequentially
+  and one tries to reduce the bias of the combined estimator. The motivation is
+  to combine several weak models to produce a powerful ensemble.
 
   **Examples:** :ref:`AdaBoost <adaboost>`, :ref:`Gradient Tree Boosting <gradient_boosting>`, ...
 
 
+.. _bagging:
+
+Bagging meta-estimator
+======================
+
+In ensemble algorithms, bagging methods form a class of algorithms which build
+several instances of a black-box estimator on random subsets of the original
+training set and then aggregate their individual predictions to form a final
+prediction. These methods are used as a way to reduce the variance of a base
+estimator (e.g., a decision tree), by introducing randomization into its
+construction procedure and then making an ensemble out of it. In many cases,
+bagging methods constitute a very simple way to improve with respect to a
+single model, without making it necessary to adapt the underlying base
+algorithm. As they provide a way to reduce overfitting, bagging methods work
+best with strong and complex models (e.g., fully developed decision trees), in
+contrast with boosting methods which usually work best with weak models (e.g.,
+shallow decision trees).
+
+Bagging methods come in many flavours but mostly differ from each other by the
+way they draw random subsets of the training set:
+
+ * When random subsets of the dataset are drawn as random subsets of the
+   samples, then this algorithm is known as Pasting [B1999]_.
+
+ * When samples are drawn with replacement, then the method is known as
+   Bagging [B1996]_.
+
+ * When random subsets of the dataset are drawn as random subsets of
+   the features, then the method is known as Random Subspaces [H1998]_.
+
+ * Finally, when base estimators are built on subsets of both samples and
+   features, then the method is known as Random Patches [LG2012]_.
+
+In scikit-learn, bagging methods are offered as a unified
+:class:`BaggingClassifier` meta-estimator (resp. :class:`BaggingRegressor`),
+taking as input a user-specified base estimator along with parameters
+specifying the strategy to draw random subsets. In particular, ``max_samples``
+and ``max_features`` control the size of the subsets (in terms of samples and
+features), while ``bootstrap`` and ``bootstrap_features`` control whether
+samples and features are drawn with or without replacement. As an example, the
+snippet below illustrates how to instantiate a bagging ensemble of
+:class:`KNeighborsClassifier` base estimators, each built on random subsets of
+50% of the samples and 50% of the features.
+
+    >>> from sklearn.ensemble import BaggingClassifier
+    >>> from sklearn.neighbors import KNeighborsClassifier
+    >>> bagging = BaggingClassifier(KNeighborsClassifier(),
+    ...                             max_samples=0.5, max_features=0.5)
+
+.. topic:: Examples:
+
+ * :ref:`example_ensemble_plot_bias_variance.py`
+
+.. topic:: References
+
+ .. [B1999] L. Breiman, "Pasting small votes for classification in large
+    databases and on-line", Machine Learning, 36(1), 85-103, 1999.
+
+ .. [B1996] L. Breiman, "Bagging predictors", Machine Learning, 24(2),
+    123-140, 1996.
+
+ .. [H1998] T. Ho, "The random subspace method for constructing decision
+    forests", Pattern Analysis and Machine Intelligence, 20(8), 832-844,
+    1998.
+
+ .. [LG2012] G. Louppe and P. Geurts, "Ensembles on Random Patches",
+    Machine Learning and Knowledge Discovery in Databases, 346-361, 2012.
+
 .. _forest:
 
 Forests of randomized trees
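
The four subset-drawing strategies listed above are all parameter choices on the
same meta-estimator. As a rough illustration (not part of the commit; the 0.5
subset sizes are arbitrary), one way to configure ``BaggingClassifier`` for each
flavour, using only the ``max_samples``, ``max_features``, ``bootstrap`` and
``bootstrap_features`` parameters described in the section, is:

    from sklearn.ensemble import BaggingClassifier
    from sklearn.neighbors import KNeighborsClassifier

    base = KNeighborsClassifier()

    # Pasting: random subsets of the samples, drawn without replacement.
    pasting = BaggingClassifier(base, max_samples=0.5, bootstrap=False)

    # Bagging: samples drawn with replacement (bootstrap samples).
    bagging = BaggingClassifier(base, max_samples=1.0, bootstrap=True)

    # Random Subspaces: random subsets of the features only.
    subspaces = BaggingClassifier(base, max_features=0.5,
                                  bootstrap=False, bootstrap_features=False)

    # Random Patches: random subsets of both samples and features.
    patches = BaggingClassifier(base, max_samples=0.5, max_features=0.5,
                                bootstrap=False, bootstrap_features=False)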
examples/ensemble/plot_bias_variance.py

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
"""
============================================================
Single estimator versus bagging: bias-variance decomposition
============================================================

This example illustrates and compares the bias-variance decomposition of the
expected mean squared error of a single estimator against a bagging ensemble.

In regression, the expected mean squared error of an estimator can be
decomposed in terms of bias, variance and noise. On average over datasets of
the regression problem, the bias term measures the average amount by which the
predictions of the estimator differ from the predictions of the best possible
estimator for the problem (i.e., the Bayes model). The variance term measures
the variability of the predictions of the estimator when fit over different
instances LS of the problem. Finally, the noise measures the irreducible part
of the error which is due to the variability in the data.

The upper left figure illustrates the predictions (in dark red) of a single
decision tree trained over a random dataset LS (the blue dots) of a toy 1d
regression problem. It also illustrates the predictions (in light red) of other
single decision trees trained over other (and different) randomly drawn
instances LS of the problem. Intuitively, the variance term here corresponds to
the width of the beam of predictions (in light red) of the individual
estimators. The larger the variance, the more sensitive are the predictions for
`x` to small changes in the training set. The bias term corresponds to the
difference between the average prediction of the estimator (in cyan) and the
best possible model (in dark blue). On this problem, we can thus observe that
the bias is quite low (both the cyan and the blue curves are close to each
other) while the variance is large (the red beam is rather wide).

The lower left figure plots the pointwise decomposition of the expected mean
squared error of a single decision tree. It confirms that the bias term (in
blue) is low while the variance is large (in green). It also illustrates the
noise part of the error which, as expected, appears to be constant and around
`0.01`.

The right figures correspond to the same plots but using instead a bagging
ensemble of decision trees. In both figures, we can observe that the bias term
is larger than in the previous case. In the upper right figure, the difference
between the average prediction (in cyan) and the best possible model is larger
(e.g., notice the offset around `x=2`). In the lower right figure, the bias
curve is also slightly higher than in the lower left figure. In terms of
variance however, the beam of predictions is narrower, which suggests that the
variance is lower. Indeed, as the lower right figure confirms, the variance
term (in green) is lower than for single decision trees. Overall, the bias-
variance decomposition is therefore no longer the same. The tradeoff is better
for bagging: averaging several decision trees fit on bootstrap copies of the
dataset slightly increases the bias term but allows for a larger reduction of
the variance, which results in a lower overall mean squared error (compare the
red curves in the lower figures). The script output also confirms this
intuition. The total error of the bagging ensemble is lower than the total
error of a single decision tree, and this difference indeed mainly stems from a
reduced variance.

For further details on bias-variance decomposition, see section 7.3 of [1]_.

References
----------

.. [1] T. Hastie, R. Tibshirani and J. Friedman,
   "Elements of Statistical Learning", Springer, 2009.

"""
print(__doc__)

# Author: Gilles Louppe <[email protected]>
# License: BSD 3 clause

import numpy as np
from matplotlib import pyplot as plt

from sklearn.ensemble import BaggingRegressor
from sklearn.tree import DecisionTreeRegressor

# Settings
n_repeat = 50       # Number of iterations for computing expectations
n_train = 50        # Size of the training set
n_test = 1000       # Size of the test set
noise = 0.1         # Standard deviation of the noise
np.random.seed(0)

# Change this for exploring the bias-variance decomposition of other
# estimators. This should work well for estimators with high variance (e.g.,
# decision trees or KNN), but poorly for estimators with low variance (e.g.,
# linear models).
estimators = [("Tree", DecisionTreeRegressor()),
              ("Bagging(Tree)", BaggingRegressor(DecisionTreeRegressor()))]

n_estimators = len(estimators)

# Generate data
def f(x):
    x = x.ravel()

    return np.exp(-x ** 2) + 1.5 * np.exp(-(x - 2) ** 2)

def generate(n_samples, noise, n_repeat=1):
    X = np.random.rand(n_samples) * 10 - 5
    X = np.sort(X)

    if n_repeat == 1:
        y = f(X) + np.random.normal(0.0, noise, n_samples)
    else:
        y = np.zeros((n_samples, n_repeat))

        for i in range(n_repeat):
            y[:, i] = f(X) + np.random.normal(0.0, noise, n_samples)

    X = X.reshape((n_samples, 1))

    return X, y

X_train = []
y_train = []

for i in range(n_repeat):
    X, y = generate(n_samples=n_train, noise=noise)
    X_train.append(X)
    y_train.append(y)

X_test, y_test = generate(n_samples=n_test, noise=noise, n_repeat=n_repeat)

# Loop over estimators to compare
for n, (name, estimator) in enumerate(estimators):
    # Compute predictions
    y_predict = np.zeros((n_test, n_repeat))

    for i in range(n_repeat):
        estimator.fit(X_train[i], y_train[i])
        y_predict[:, i] = estimator.predict(X_test)

    # Bias^2 + Variance + Noise decomposition of the mean squared error
    y_error = np.zeros(n_test)

    for i in range(n_repeat):
        for j in range(n_repeat):
            y_error += (y_test[:, j] - y_predict[:, i]) ** 2

    y_error /= (n_repeat * n_repeat)

    y_noise = np.var(y_test, axis=1)
    y_bias = (f(X_test) - np.mean(y_predict, axis=1)) ** 2
    y_var = np.var(y_predict, axis=1)

    print("{0}: {1:.4f} (error) = {2:.4f} (bias^2) "
          " + {3:.4f} (var) + {4:.4f} (noise)".format(name,
                                                      np.mean(y_error),
                                                      np.mean(y_bias),
                                                      np.mean(y_var),
                                                      np.mean(y_noise)))

    # Plot figures
    plt.subplot(2, n_estimators, n + 1)
    plt.plot(X_test, f(X_test), "b", label="$f(x)$")
    plt.plot(X_train[0], y_train[0], ".b", label="LS ~ $y = f(x)+noise$")

    for i in range(n_repeat):
        if i == 0:
            plt.plot(X_test, y_predict[:, i], "r", label=r"$\^y(x)$")
        else:
            plt.plot(X_test, y_predict[:, i], "r", alpha=0.05)

    plt.plot(X_test, np.mean(y_predict, axis=1), "c",
             label=r"$\mathbb{E}_{LS} \^y(x)$")

    plt.xlim([-5, 5])
    plt.title(name)

    if n == 0:
        plt.legend(loc="upper left", prop={"size": 11})

    plt.subplot(2, n_estimators, n_estimators + n + 1)
    plt.plot(X_test, y_error, "r", label="$error(x)$")
    plt.plot(X_test, y_bias, "b", label="$bias^2(x)$")
    plt.plot(X_test, y_var, "g", label="$variance(x)$")
    plt.plot(X_test, y_noise, "c", label="$noise(x)$")

    plt.xlim([-5, 5])
    plt.ylim([0, 0.1])

    if n == 0:
        plt.legend(loc="upper left", prop={"size": 11})

plt.show()
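
The docstring above states the pointwise decomposition error = bias^2 + variance
+ noise in prose. The short sketch below (not part of the commit; the test-point
value and the two distributions are made up) checks that identity numerically at
a single fixed test point, using the same empirical estimators as the example:

    import numpy as np

    rng = np.random.RandomState(0)
    f_x = 1.5                                      # true regression value at x
    y_obs = f_x + rng.normal(0.0, 0.1, 200)        # noisy targets around f(x)
    y_hat = f_x + 0.2 + rng.normal(0.0, 0.3, 200)  # biased, variable predictions

    # Average squared error over all (target, prediction) pairs, as in the
    # double loop of the example above.
    error = np.mean((y_obs[:, None] - y_hat[None, :]) ** 2)

    noise = y_obs.var()                         # irreducible noise estimate
    variance = y_hat.var()                      # spread of the predictions
    bias2 = (y_obs.mean() - y_hat.mean()) ** 2  # squared bias estimate

    # With these definitions the identity holds up to rounding; the example
    # measures the bias against the true f(x) instead, which agrees in
    # expectation.
    print(error, bias2 + variance + noise)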

sklearn/ensemble/__init__.py

Lines changed: 8 additions & 3 deletions
@@ -9,19 +9,24 @@
 from .forest import RandomTreesEmbedding
 from .forest import ExtraTreesClassifier
 from .forest import ExtraTreesRegressor
+from .bagging import BaggingClassifier
+from .bagging import BaggingRegressor
 from .weight_boosting import AdaBoostClassifier
 from .weight_boosting import AdaBoostRegressor
 from .gradient_boosting import GradientBoostingClassifier
 from .gradient_boosting import GradientBoostingRegressor
 
+from . import bagging
 from . import forest
 from . import weight_boosting
 from . import gradient_boosting
 from . import partial_dependence
 
-__all__ = ["BaseEnsemble", "RandomForestClassifier", "RandomForestRegressor",
+__all__ = ["BaseEnsemble",
+           "RandomForestClassifier", "RandomForestRegressor",
            "RandomTreesEmbedding", "ExtraTreesClassifier",
-           "ExtraTreesRegressor", "GradientBoostingClassifier",
+           "ExtraTreesRegressor", "BaggingClassifier",
+           "BaggingRegressor", "GradientBoostingClassifier",
            "GradientBoostingRegressor", "AdaBoostClassifier",
-           "AdaBoostRegressor", "forest", "gradient_boosting",
+           "AdaBoostRegressor", "bagging", "forest", "gradient_boosting",
            "partial_dependence", "weight_boosting"]
