Commit 532c54c

Merge branch 'master' of github.com:scikit-learn/scikit-learn into tree-mo
Conflicts: sklearn/tree/_tree.c
2 parents: dc8e65a + 1d4c087

File tree: 8 files changed (+390 / -41 lines)

doc/modules/ensemble.rst

Lines changed: 26 additions & 11 deletions
@@ -165,7 +165,8 @@ amount of time (e.g., on large datasets).

   * :ref:`example_ensemble_plot_forest_iris.py`
   * :ref:`example_ensemble_plot_forest_importances_faces.py`
-  * :ref:`example_ensemble_plot_forest_multioutput.py`
+  * :ref:`example_ensemble_plot_forest_multioutput.py`
+

 .. topic:: References

@@ -177,9 +178,6 @@ amount of time (e.g., on large datasets).
    trees", Machine Learning, 63(1), 3-42, 2006.


-.. _gradient_boosting:
-
-
 Feature importance evaluation
 -----------------------------

@@ -219,6 +217,8 @@ the matching feature to the prediction function.
   * :ref:`example_ensemble_plot_forest_importances.py`


+.. _gradient_boosting:
+
 Gradient Tree Boosting
 ======================

@@ -284,11 +284,10 @@ that controls overfitting via :ref:`shrinkage <gradient_boosting_shrinkage>`.
 Regression
 ----------

-:class:`GradientBoostingRegressor` supports a number of different loss
-functions for regression which can be specified via the argument
-``loss``. Currently, supported are least squares (``loss='ls'``) and
-least absolute deviation (``loss='lad'``), which is more robust w.r.t.
-outliers. See [F2001]_ for detailed information.
+:class:`GradientBoostingRegressor` supports a number of
+:ref:`different loss functions <gradient_boosting_loss>`
+for regression which can be specified via the argument
+``loss``, which defaults to least squares (``'ls'``).

 ::

@@ -378,6 +377,7 @@ Where the step length :math:`\gamma_m` is chosen using line search:
 The algorithms for regression and classification
 only differ in the concrete loss function used.

+.. _gradient_boosting_loss:

 Loss Functions
 ...............
@@ -393,6 +393,13 @@ the parameter ``loss``:
   * Least absolute deviation (``'lad'``): A robust loss function for
     regression. The initial model is given by the median of the
     target values.
+  * Huber (``'huber'``): Another robust loss function that combines
+    least squares and least absolute deviation; use ``alpha`` to
+    control the sensitivity w.r.t. outliers (see [F2001]_ for more
+    details).
+  * Quantile (``'quantile'``): A loss function for quantile regression.
+    Use ``0 < alpha < 1`` to specify the quantile. This loss function
+    can be used to create prediction intervals.

 * Classification
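
The bullets added above introduce the ``'huber'`` and ``'quantile'`` losses and their ``alpha`` parameter. A minimal usage sketch, not from this commit; it assumes the ``learn_rate`` parameter spelling used throughout this diff and runs on made-up toy data:

import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

# Toy 1-D regression data with a handful of large outliers (illustrative only).
rng = np.random.RandomState(0)
X = rng.uniform(0, 10, size=(200, 1))
y = np.sin(X).ravel() + rng.normal(scale=0.3, size=200)
y[::20] += 5.0  # inject outliers

# Huber loss: ``alpha`` controls the sensitivity w.r.t. outliers.
huber = GradientBoostingRegressor(loss='huber', alpha=0.9,
                                  n_estimators=100, learn_rate=0.1)
huber.fit(X, y)

# Quantile loss: ``alpha`` is the target quantile (here the 90th percentile),
# the building block for prediction intervals.
q90 = GradientBoostingRegressor(loss='quantile', alpha=0.9,
                                n_estimators=100, learn_rate=0.1)
q90.fit(X, y)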

@@ -443,8 +450,7 @@ Subsampling
 [F1999]_ proposed stochastic gradient boosting, which combines gradient
 boosting with bootstrap averaging (bagging). At each iteration
 the base classifier is trained on a fraction ``subsample`` of
-the available training data.
-The subsample is drawn without replacement.
+the available training data. The subsample is drawn without replacement.
 A typical value of ``subsample`` is 0.5.

 The figure below illustrates the effect of shrinkage and subsampling
@@ -458,12 +464,21 @@ does poorly.
    :align: center
    :scale: 75

+For ``subsample < 1``, the deviance on the out-of-bag samples in the i-th iteration
+is stored in the attribute ``oob_score_[i]``. Out-of-bag estimates can be
+used for model selection (e.g. to determine the optimal number of iterations).
+
+Another strategy to reduce the variance is to subsample the features,
+analogous to the random splits in Random Forests. The size of the subsample
+can be controlled via the ``max_features`` parameter.
+

 .. topic:: Examples:

  * :ref:`example_ensemble_plot_gradient_boosting_regression.py`
  * :ref:`example_ensemble_plot_gradient_boosting_regularization.py`

+
 .. topic:: References

  .. [F2001] J. Friedman, "Greedy Function Approximation: A Gradient Boosting Machine",
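
The paragraphs added in this hunk describe three regularization knobs: ``subsample``, the per-iteration out-of-bag deviance in ``oob_score_``, and feature subsampling via ``max_features``. A hedged sketch of combining them, assuming the parameter and attribute names exactly as documented above (``learn_rate``, ``oob_score_``) and synthetic data:

import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.RandomState(0)
X = rng.uniform(0, 10, size=(500, 4))
y = np.sin(X[:, 0]) + 0.1 * rng.normal(size=500)

# Each tree is fit on half of the samples (drawn without replacement) and
# considers only two of the four features when searching for a split.
clf = GradientBoostingRegressor(n_estimators=200, learn_rate=0.1,
                                subsample=0.5, max_features=2)
clf.fit(X, y)

# With subsample < 1, oob_score_[i] holds the out-of-bag deviance of
# iteration i; its minimum gives a rough estimate of how many iterations
# are actually needed.
n_iter = np.argmin(clf.oob_score_) + 1
print("iterations suggested by the OOB deviance: %d" % n_iter)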

doc/whats_new.rst

Lines changed: 7 additions & 0 deletions
@@ -9,6 +9,13 @@
 Changelog
 ---------

+- :class:`ensemble.GradientBoostingRegressor` and
+  :class:`ensemble.GradientBoostingClassifier` now support feature subsampling
+  via the ``max_features`` argument.
+
+- Added Huber and Quantile loss functions to
+  :class:`ensemble.GradientBoostingRegressor`.
+
 - Added :class:`preprocessing.LabelBinarizer`, a simple utility class to
   normalize labels or transform non-numerical labels, by `Mathieu Blondel`_.

(new example file; path not captured in this view)

Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@
+"""
+=====================================================
+Prediction Intervals for Gradient Boosting Regression
+=====================================================
+
+This example shows how quantile regression can be used
+to create prediction intervals.
+"""
+
+import numpy as np
+import pylab as pl
+from sklearn.ensemble import GradientBoostingRegressor
+
+
+np.random.seed(1)
+
+
+def f(x):
+    """The function to predict."""
+    return x * np.sin(x)
+
+#----------------------------------------------------------------------
+# Generate noisy observations of f
+X = np.atleast_2d(np.random.uniform(0, 10.0, size=100)).T
+X = X.astype(np.float32)
+
+# Observations
+y = f(X).ravel()
+
+dy = 1.5 + 1.0 * np.random.random(y.shape)
+noise = np.random.normal(0, dy)
+y += noise
+y = y.astype(np.float32)
+
+# Mesh the input space for evaluation of the real function
+# and the predictions
+xx = np.atleast_2d(np.linspace(0, 10, 1000)).T
+xx = xx.astype(np.float32)
+
+alpha = 0.95
+
+clf = GradientBoostingRegressor(loss='quantile', alpha=alpha,
+                                n_estimators=250, max_depth=3,
+                                learn_rate=.1, min_samples_leaf=9,
+                                min_samples_split=9)
+
+clf.fit(X, y)
+
+# Make the prediction on the meshed x-axis (upper quantile)
+y_upper = clf.predict(xx)
+
+clf.set_params(alpha=1.0 - alpha)
+clf.fit(X, y)
+
+# Make the prediction on the meshed x-axis (lower quantile)
+y_lower = clf.predict(xx)
+
+clf.set_params(loss='ls')
+clf.fit(X, y)
+
+# Make the prediction on the meshed x-axis (least-squares estimate)
+y_pred = clf.predict(xx)
+
+# Plot the function, the prediction and the prediction interval
+# given by the two quantile models
+fig = pl.figure()
+pl.plot(xx, f(xx), 'g:', label=u'$f(x) = x\,\sin(x)$')
+pl.plot(X, y, 'b.', markersize=10, label=u'Observations')
+pl.plot(xx, y_pred, 'r-', label=u'Prediction')
+pl.plot(xx, y_upper, 'k-')
+pl.plot(xx, y_lower, 'k-')
+pl.fill(np.concatenate([xx, xx[::-1]]),
+        np.concatenate([y_upper, y_lower[::-1]]),
+        alpha=.5, fc='b', ec='None', label='95% prediction interval')
+pl.xlabel('$x$')
+pl.ylabel('$f(x)$')
+pl.ylim(-10, 20)
+pl.legend(loc='upper left')
+pl.show()
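
A possible follow-up to the example above, not part of the commit: because the upper and lower curves are the 0.95 and 0.05 quantiles, the band's nominal coverage is 90% (the plot label says 95%), and the empirical coverage on the training points can be checked directly. The snippet reuses ``clf``, ``X``, ``y`` and ``alpha`` from the script above.

# Refit the two quantile models and count how many training targets fall
# inside the band (set_params and fit both return the estimator itself).
upper_train = clf.set_params(loss='quantile', alpha=alpha).fit(X, y).predict(X)
lower_train = clf.set_params(alpha=1.0 - alpha).fit(X, y).predict(X)
coverage = np.mean((y >= lower_train) & (y <= upper_train))
print("empirical coverage on the training set: %.2f" % coverage)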

examples/ensemble/plot_gradient_boosting_regularization.py

Lines changed: 15 additions & 9 deletions
@@ -6,10 +6,15 @@
 Illustration of the effect of different regularization strategies
 for Gradient Boosting. The example is taken from Hastie et al. 2009.

-The loss function used is binomial deviance. In combination with
-shrinkage, stochastic gradient boosting (Sample 0.5) can produce
-more accurate models.
+The loss function used is binomial deviance. Regularization via
+shrinkage (``learn_rate < 1.0``) improves performance considerably.
+In combination with shrinkage, stochastic gradient boosting
+(``subsample < 1.0``) can produce more accurate models by reducing the
+variance via bagging.
 Subsampling without shrinkage usually does poorly.
+Another strategy to reduce the variance is to subsample the features,
+analogous to the random splits in Random Forests
+(via the ``max_features`` parameter).

 .. [1] T. Hastie, R. Tibshirani and J. Friedman, "Elements of Statistical
     Learning Ed. 2", Springer, 2009.
@@ -39,12 +44,14 @@

 for label, color, setting in [('No shrinkage', 'orange',
                                {'learn_rate': 1.0, 'subsample': 1.0}),
-                              ('Shrink=0.1', 'turquoise',
+                              ('learn_rate=0.1', 'turquoise',
                                {'learn_rate': 0.1, 'subsample': 1.0}),
-                              ('Sample=0.5', 'blue',
+                              ('subsample=0.5', 'blue',
                                {'learn_rate': 1.0, 'subsample': 0.5}),
-                              ('Shrink=0.1, Sample=0.5', 'gray',
-                               {'learn_rate': 0.1, 'subsample': 0.5})]:
+                              ('learn_rate=0.1, subsample=0.5', 'gray',
+                               {'learn_rate': 0.1, 'subsample': 0.5}),
+                              ('learn_rate=0.1, max_features=2', 'magenta',
+                               {'learn_rate': 0.1, 'max_features': 2})]:
     params = dict(original_params)
     params.update(setting)

@@ -57,10 +64,9 @@
     for i, y_pred in enumerate(clf.staged_decision_function(X_test)):
         test_deviance[i] = clf.loss_(y_test, y_pred)

-    pl.plot(np.arange(test_deviance.shape[0]) + 1, test_deviance, '-',
+    pl.plot((np.arange(test_deviance.shape[0]) + 1)[::5], test_deviance[::5], '-',
             color=color, label=label)

-pl.title('Deviance')
 pl.legend(loc='upper left')
 pl.xlabel('Boosting Iterations')
 pl.ylabel('Test Set Deviance')

sklearn/ensemble/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -8,6 +8,5 @@
 from .forest import RandomForestRegressor
 from .forest import ExtraTreesClassifier
 from .forest import ExtraTreesRegressor
-
 from .gradient_boosting import GradientBoostingClassifier
 from .gradient_boosting import GradientBoostingRegressor
