
Commit 5cf88db

NicolasHug authored and glemaitre committed
ENH Implement least absolute deviation loss in GBDTs (scikit-learn#13896)
1 parent 813b601 commit 5cf88db

File tree

10 files changed: +160 -30 lines changed


benchmarks/bench_hist_gradient_boosting.py

Lines changed: 12 additions & 15 deletions
@@ -26,6 +26,7 @@
 parser.add_argument('--learning-rate', type=float, default=.1)
 parser.add_argument('--problem', type=str, default='classification',
                     choices=['classification', 'regression'])
+parser.add_argument('--loss', type=str, default='default')
 parser.add_argument('--missing-fraction', type=float, default=0)
 parser.add_argument('--n-classes', type=int, default=2)
 parser.add_argument('--n-samples-max', type=int, default=int(1e6))
@@ -81,6 +82,17 @@ def one_run(n_samples):
                     n_iter_no_change=None,
                     random_state=0,
                     verbose=0)
+    loss = args.loss
+    if args.problem == 'classification':
+        if loss == 'default':
+            # loss='auto' does not work with get_equivalent_estimator()
+            loss = 'binary_crossentropy' if args.n_classes == 2 else \
+                'categorical_crossentropy'
+    else:
+        # regression
+        if loss == 'default':
+            loss = 'least_squares'
+    est.set_params(loss=loss)
     est.fit(X_train, y_train)
     sklearn_fit_duration = time() - tic
     tic = time()
@@ -95,11 +107,6 @@ def one_run(n_samples):
     lightgbm_score_duration = None
     if args.lightgbm:
         print("Fitting a LightGBM model...")
-        # get_lightgbm does not accept loss='auto'
-        if args.problem == 'classification':
-            loss = 'binary_crossentropy' if args.n_classes == 2 else \
-                'categorical_crossentropy'
-        est.set_params(loss=loss)
         lightgbm_est = get_equivalent_estimator(est, lib='lightgbm')

         tic = time()
@@ -117,11 +124,6 @@ def one_run(n_samples):
     xgb_score_duration = None
     if args.xgboost:
         print("Fitting an XGBoost model...")
-        # get_xgb does not accept loss='auto'
-        if args.problem == 'classification':
-            loss = 'binary_crossentropy' if args.n_classes == 2 else \
-                'categorical_crossentropy'
-        est.set_params(loss=loss)
         xgb_est = get_equivalent_estimator(est, lib='xgboost')

         tic = time()
@@ -139,11 +141,6 @@ def one_run(n_samples):
     cat_score_duration = None
    if args.catboost:
         print("Fitting a CatBoost model...")
-        # get_cat does not accept loss='auto'
-        if args.problem == 'classification':
-            loss = 'binary_crossentropy' if args.n_classes == 2 else \
-                'categorical_crossentropy'
-        est.set_params(loss=loss)
         cat_est = get_equivalent_estimator(est, lib='catboost')

         tic = time()

doc/modules/ensemble.rst

Lines changed: 7 additions & 0 deletions
@@ -878,6 +878,13 @@ controls the number of iterations of the boosting process::
     >>> clf.score(X_test, y_test)
     0.8965

+Available losses for regression are 'least_squares' and
+'least_absolute_deviation', which is less sensitive to outliers. For
+classification, 'binary_crossentropy' is used for binary classification and
+'categorical_crossentropy' is used for multiclass classification. By default
+the loss is 'auto' and will select the appropriate loss depending on
+:term:`y` passed to :term:`fit`.
+
 The size of the trees can be controlled through the ``max_leaf_nodes``,
 ``max_depth``, and ``min_samples_leaf`` parameters.

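As a quick illustration of the documented behaviour, here is a minimal usage sketch (synthetic data; the explicit experimental-enable import is still required at this release):

    from sklearn.experimental import enable_hist_gradient_boosting  # noqa
    from sklearn.ensemble import HistGradientBoostingRegressor
    from sklearn.datasets import make_regression

    # Synthetic data: the L1 loss targets the median, so it is less
    # sensitive to outliers than 'least_squares'.
    X, y = make_regression(n_samples=500, noise=10.0, random_state=0)
    est = HistGradientBoostingRegressor(loss='least_absolute_deviation')
    est.fit(X, y)
    print(est.score(X, y))  # R^2 on the training data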

doc/whats_new/v0.22.rst

Lines changed: 2 additions & 0 deletions
@@ -136,6 +136,8 @@ Changelog
 - |Feature| :func:`inspection.partial_dependence` and
   :func:`inspection.plot_partial_dependence` now support the fast 'recursion'
   method for both estimators. :pr:`13769` by `Nicolas Hug`_.
+- |Enhancement| :class:`ensemble.HistGradientBoostingRegressor` now supports
+  the 'least_absolute_deviation' loss. :pr:`13896` by `Nicolas Hug`_.
 - |Fix| Estimators now bin the training and validation data separately to
   avoid any data leak. :pr:`13933` by `Nicolas Hug`_.
 - |Fix| Fixed a bug where early stopping would break with string targets.

sklearn/ensemble/_hist_gradient_boosting/_loss.pyx

Lines changed: 15 additions & 3 deletions
@@ -27,12 +27,24 @@ def _update_gradients_least_squares(

     n_samples = raw_predictions.shape[0]
     for i in prange(n_samples, schedule='static', nogil=True):
-        # Note: a more correct exp is 2 * (raw_predictions - y_true) but
-        # since we use 1 for the constant hessian value (and not 2) this
-        # is strictly equivalent for the leaves values.
         gradients[i] = raw_predictions[i] - y_true[i]


+def _update_gradients_least_absolute_deviation(
+        G_H_DTYPE_C [::1] gradients,  # OUT
+        const Y_DTYPE_C [::1] y_true,  # IN
+        const Y_DTYPE_C [::1] raw_predictions):  # IN
+
+    cdef:
+        int n_samples
+        int i
+
+    n_samples = raw_predictions.shape[0]
+    for i in prange(n_samples, schedule='static', nogil=True):
+        # gradient = sign(raw_prediction - y_true)
+        gradients[i] = 2 * (y_true[i] - raw_predictions[i] < 0) - 1
+
+
 def _update_gradients_hessians_binary_crossentropy(
         G_H_DTYPE_C [::1] gradients,  # OUT
         G_H_DTYPE_C [::1] hessians,  # OUT
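For readers who prefer to see the boolean trick outside Cython, here is a NumPy sketch of the new kernel (the function name is illustrative, not part of the module):

    import numpy as np

    def lad_gradients(y_true, raw_predictions):
        # The gradient of |y_true - raw| w.r.t. raw is sign(raw - y_true);
        # the comparison yields {0, 1} and 2 * b - 1 maps that to {-1, +1},
        # mirroring the Cython loop above without branching.
        return (2 * (y_true - raw_predictions < 0) - 1).astype(np.float64)

    y_true = np.array([1.0, 2.0, 3.0])
    raw = np.array([2.0, 2.0, 2.0])
    print(lad_gradients(y_true, raw))  # [ 1. -1. -1.]

Because the hessians of an L1 loss are formally zero (and constant), only the gradient kernel is needed; the constant-hessian path of the existing least-squares loss is reused.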

sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py

Lines changed: 7 additions & 2 deletions
@@ -322,6 +322,10 @@ def fit(self, X, y):
                 acc_find_split_time += grower.total_find_split_time
                 acc_compute_hist_time += grower.total_compute_hist_time

+                if self.loss_.need_update_leaves_values:
+                    self.loss_.update_leaves_values(grower, y_train,
+                                                    raw_predictions[k, :])
+
                 predictor = grower.make_predictor(
                     bin_thresholds=self.bin_mapper_.bin_thresholds_
                 )
@@ -672,7 +676,8 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):

     Parameters
     ----------
-    loss : {'least_squares'}, optional (default='least_squares')
+    loss : {'least_squares', 'least_absolute_deviation'}, \
+            optional (default='least_squares')
         The loss function to use in the boosting process. Note that the
         "least squares" loss actually implements an "half least squares loss"
         to simplify the computation of the gradient.
@@ -770,7 +775,7 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
     0.98...
     """

-    _VALID_LOSSES = ('least_squares',)
+    _VALID_LOSSES = ('least_squares', 'least_absolute_deviation')

     def __init__(self, loss='least_squares', learning_rate=0.1,
                  max_iter=100, max_leaf_nodes=31, max_depth=None,

sklearn/ensemble/_hist_gradient_boosting/loss.py

Lines changed: 73 additions & 3 deletions
@@ -18,13 +18,24 @@
 from .common import Y_DTYPE
 from .common import G_H_DTYPE
 from ._loss import _update_gradients_least_squares
+from ._loss import _update_gradients_least_absolute_deviation
 from ._loss import _update_gradients_hessians_binary_crossentropy
 from ._loss import _update_gradients_hessians_categorical_crossentropy


 class BaseLoss(ABC):
     """Base class for a loss."""

+    # This variable indicates whether the loss requires the leaves values to
+    # be updated once the tree has been trained. The trees are trained to
+    # predict a Newton-Raphson step (see grower._finalize_leaf()). But for
+    # some losses (e.g. least absolute deviation) we need to adjust the tree
+    # values to account for the "line search" of the gradient descent
+    # procedure. See the original paper Greedy Function Approximation: A
+    # Gradient Boosting Machine by Friedman
+    # (https://statweb.stanford.edu/~jhf/ftp/trebst.pdf) for the theory.
+    need_update_leaves_values = False
+
     def init_gradients_and_hessians(self, n_samples, prediction_dim):
         """Return initial gradients and hessians.

@@ -53,9 +64,10 @@ def init_gradients_and_hessians(self, n_samples, prediction_dim):
         shape = (prediction_dim, n_samples)
         gradients = np.empty(shape=shape, dtype=G_H_DTYPE)
         if self.hessians_are_constant:
-            # if the hessians are constant, we consider they are equal to 1.
-            # this is correct as long as we adjust the gradients. See e.g. LS
-            # loss
+            # If the hessians are constant, we consider they are equal to 1.
+            # - This is correct for the half LS loss
+            # - For LAD loss, hessians are actually 0, but they are always
+            #   ignored anyway.
             hessians = np.ones(shape=(1, 1), dtype=G_H_DTYPE)
         else:
             hessians = np.empty(shape=shape, dtype=G_H_DTYPE)
@@ -141,6 +153,63 @@ def update_gradients_and_hessians(self, gradients, hessians, y_true,
         _update_gradients_least_squares(gradients, y_true, raw_predictions)


+class LeastAbsoluteDeviation(BaseLoss):
+    """Least absolute deviation, for regression.
+
+    For a given sample x_i, the loss is defined as::
+
+        loss(x_i) = |y_true_i - raw_pred_i|
+    """
+
+    hessians_are_constant = True
+    # This variable indicates whether the loss requires the leaves values to
+    # be updated once the tree has been trained. The trees are trained to
+    # predict a Newton-Raphson step (see grower._finalize_leaf()). But for
+    # some losses (e.g. least absolute deviation) we need to adjust the tree
+    # values to account for the "line search" of the gradient descent
+    # procedure. See the original paper Greedy Function Approximation: A
+    # Gradient Boosting Machine by Friedman
+    # (https://statweb.stanford.edu/~jhf/ftp/trebst.pdf) for the theory.
+    need_update_leaves_values = True
+
+    def __call__(self, y_true, raw_predictions, average=True):
+        # shape (1, n_samples) --> (n_samples,). reshape(-1) is more likely to
+        # return a view.
+        raw_predictions = raw_predictions.reshape(-1)
+        loss = np.abs(y_true - raw_predictions)
+        return loss.mean() if average else loss
+
+    def get_baseline_prediction(self, y_train, prediction_dim):
+        return np.median(y_train)
+
+    @staticmethod
+    def inverse_link_function(raw_predictions):
+        return raw_predictions
+
+    def update_gradients_and_hessians(self, gradients, hessians, y_true,
+                                      raw_predictions):
+        # shape (1, n_samples) --> (n_samples,). reshape(-1) is more likely to
+        # return a view.
+        raw_predictions = raw_predictions.reshape(-1)
+        gradients = gradients.reshape(-1)
+        _update_gradients_least_absolute_deviation(gradients, y_true,
+                                                   raw_predictions)
+
+    def update_leaves_values(self, grower, y_true, raw_predictions):
+        # Update the values predicted by the tree with
+        # median(y_true - raw_predictions).
+        # See note about need_update_leaves_values in BaseLoss.
+
+        # TODO: ideally this should be computed in parallel over the leaves
+        # using something similar to _update_raw_predictions(), but this
+        # requires a cython version of median()
+        for leaf in grower.finalized_leaves:
+            indices = leaf.sample_indices
+            median_res = np.median(y_true[indices] - raw_predictions[indices])
+            leaf.value = grower.shrinkage * median_res
+            # Note that the regularization is ignored here
+
+
 class BinaryCrossEntropy(BaseLoss):
     """Binary cross-entropy loss, for binary classification.

@@ -242,6 +311,7 @@ def predict_proba(self, raw_predictions):

 _LOSSES = {
     'least_squares': LeastSquares,
+    'least_absolute_deviation': LeastAbsoluteDeviation,
     'binary_crossentropy': BinaryCrossEntropy,
     'categorical_crossentropy': CategoricalCrossEntropy
 }
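The leaf correction above can be summarized outside the grower: for an L1 loss, the value that minimizes the in-leaf objective is the median of the residuals rather than the Newton-Raphson step the tree was trained to predict. A standalone sketch, where the Leaf container is hypothetical and only mirrors the sample_indices/value attributes of the grower's finalized leaves:

    import numpy as np
    from dataclasses import dataclass

    @dataclass
    class Leaf:
        # Hypothetical stand-in for a finalized grower leaf: only the two
        # attributes used by the correction are mirrored here.
        sample_indices: np.ndarray
        value: float = 0.0

    def update_leaves_values(leaves, y_true, raw_predictions, shrinkage):
        # For an L1 loss, argmin_v sum_i |y_true_i - (raw_i + v)| over a
        # leaf's samples is the median of the residuals, so each leaf value
        # is overwritten with it (scaled by the learning rate; regularization
        # is ignored, as in the class above).
        for leaf in leaves:
            idx = leaf.sample_indices
            leaf.value = shrinkage * np.median(y_true[idx] - raw_predictions[idx])

    y_true = np.array([0., 1., 2., 10.])
    raw = np.zeros(4)
    leaf = Leaf(sample_indices=np.arange(4))
    update_leaves_values([leaf], y_true, raw, shrinkage=0.1)
    print(leaf.value)  # 0.1 * median([0, 1, 2, 10]) = 0.15

The same reasoning explains the baseline prediction: the constant that minimizes the mean absolute error is the median of y_train.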

sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py

Lines changed: 7 additions & 0 deletions
@@ -39,6 +39,13 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples,
     #   and max_leaf_nodes is low enough.
     # - To ignore discrepancies caused by small differences the binning
     #   strategy, data is pre-binned if n_samples > 255.
+    # - We don't check the least_absolute_deviation loss here. This is because
+    #   LightGBM's computation of the median (used for the initial value of
+    #   raw_prediction) is a bit off (they'll e.g. return midpoints when there
+    #   is no need to.). Since these tests only run 1 iteration, the
+    #   discrepancy between the initial values leads to biggish differences in
+    #   the predictions. These differences are much smaller with more
+    #   iterations.

     rng = np.random.RandomState(seed=seed)
     n_samples = n_samples

sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py

Lines changed: 9 additions & 0 deletions
@@ -155,6 +155,15 @@ def test_should_stop(scores, n_iter_no_change, tol, stopping):
     assert gbdt._should_stop(scores) == stopping


+def test_least_absolute_deviation():
+    # For coverage only.
+    X, y = make_regression(n_samples=500, random_state=0)
+    gbdt = HistGradientBoostingRegressor(loss='least_absolute_deviation',
+                                         random_state=0)
+    gbdt.fit(X, y)
+    assert gbdt.score(X, y) > .9
+
+
 def test_binning_train_validation_are_separated():
     # Make sure training and validation data are binned separately.
     # See issue 13926

sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py

Lines changed: 24 additions & 7 deletions
@@ -32,9 +32,12 @@ def get_hessians(y_true, raw_predictions):

     if loss.__class__.__name__ == 'LeastSquares':
         # hessians aren't updated because they're constant:
-        # the value is 1 because the loss is actually an half
+        # the value is 1 (and not 2) because the loss is actually an half
         # least squares loss.
         hessians = np.full_like(raw_predictions, fill_value=1)
+    elif loss.__class__.__name__ == 'LeastAbsoluteDeviation':
+        # hessians aren't updated because they're constant
+        hessians = np.full_like(raw_predictions, fill_value=0)

     return hessians

@@ -81,6 +84,7 @@ def fprime2(x):

 @pytest.mark.parametrize('loss, n_classes, prediction_dim', [
     ('least_squares', 0, 1),
+    ('least_absolute_deviation', 0, 1),
     ('binary_crossentropy', 2, 1),
     ('categorical_crossentropy', 3, 3),
 ])
@@ -94,7 +98,7 @@ def test_numerical_gradients(loss, n_classes, prediction_dim, seed=0):

     rng = np.random.RandomState(seed)
     n_samples = 100
-    if loss == 'least_squares':
+    if loss in ('least_squares', 'least_absolute_deviation'):
         y_true = rng.normal(size=n_samples).astype(Y_DTYPE)
     else:
         y_true = rng.randint(0, n_classes, size=n_samples).astype(Y_DTYPE)
@@ -128,11 +132,8 @@ def test_numerical_gradients(loss, n_classes, prediction_dim, seed=0):
     f = loss(y_true, raw_predictions, average=False)
     numerical_hessians = (f_plus_eps + f_minus_eps - 2 * f) / eps**2

-    def relative_error(a, b):
-        return np.abs(a - b) / np.maximum(np.abs(a), np.abs(b))
-
-    assert_allclose(numerical_gradients, gradients, rtol=1e-4)
-    assert_allclose(numerical_hessians, hessians, rtol=1e-4)
+    assert_allclose(numerical_gradients, gradients, rtol=1e-4, atol=1e-7)
+    assert_allclose(numerical_hessians, hessians, rtol=1e-4, atol=1e-7)


 def test_baseline_least_squares():
@@ -145,6 +146,22 @@ def test_baseline_least_squares():
     assert baseline_prediction.dtype == y_train.dtype
     # Make sure baseline prediction is the mean of all targets
     assert_almost_equal(baseline_prediction, y_train.mean())
+    assert np.allclose(loss.inverse_link_function(baseline_prediction),
+                       baseline_prediction)
+
+
+def test_baseline_least_absolute_deviation():
+    rng = np.random.RandomState(0)
+
+    loss = _LOSSES['least_absolute_deviation']()
+    y_train = rng.normal(size=100)
+    baseline_prediction = loss.get_baseline_prediction(y_train, 1)
+    assert baseline_prediction.shape == tuple()  # scalar
+    assert baseline_prediction.dtype == y_train.dtype
+    # Make sure baseline prediction is the median of all targets
+    assert np.allclose(loss.inverse_link_function(baseline_prediction),
+                       baseline_prediction)
+    assert baseline_prediction == pytest.approx(np.median(y_train))


 def test_baseline_binary_crossentropy():
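The numerical check for the new loss can be reproduced in isolation. A sketch of the central-difference idea used by test_numerical_gradients (all names here are local to the sketch):

    import numpy as np

    def lad_loss(y_true, raw):
        return np.abs(y_true - raw)

    def lad_gradient(y_true, raw):
        # sign(raw - y_true), written with the same boolean trick as the kernel
        return 2.0 * (y_true - raw < 0) - 1.0

    rng = np.random.RandomState(0)
    y_true = rng.normal(size=100)
    raw = rng.normal(size=100)

    # Central difference; eps is small enough that the kink of |.| at
    # y_true == raw is (almost surely) not crossed for Gaussian data.
    eps = 1e-6
    numerical = (lad_loss(y_true, raw + eps) - lad_loss(y_true, raw - eps)) / (2 * eps)
    assert np.allclose(numerical, lad_gradient(y_true, raw), rtol=1e-4, atol=1e-7)

The second-order check degenerates to zero away from the kink, which is one reason the updated assertions above add an absolute tolerance.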

sklearn/ensemble/_hist_gradient_boosting/utils.pyx

Lines changed: 4 additions & 0 deletions
@@ -43,6 +43,7 @@ def get_equivalent_estimator(estimator, lib='lightgbm'):

     lightgbm_loss_mapping = {
         'least_squares': 'regression_l2',
+        'least_absolute_deviation': 'regression_l1',
         'binary_crossentropy': 'binary',
         'categorical_crossentropy': 'multiclass'
     }
@@ -75,6 +76,7 @@ def get_equivalent_estimator(estimator, lib='lightgbm'):
     # XGB
     xgboost_loss_mapping = {
         'least_squares': 'reg:linear',
+        'least_absolute_deviation': 'LEAST_ABSOLUTE_DEV_NOT_SUPPORTED',
         'binary_crossentropy': 'reg:logistic',
         'categorical_crossentropy': 'multi:softmax'
     }
@@ -98,6 +100,8 @@ def get_equivalent_estimator(estimator, lib='lightgbm'):
     # Catboost
     catboost_loss_mapping = {
         'least_squares': 'RMSE',
+        # catboost does not support MAE when leaf_estimation_method is Newton
+        'least_absolute_deviation': 'LEAST_ABSOLUTE_DEV_NOT_SUPPORTED',
         'binary_crossentropy': 'Logloss',
         'categorical_crossentropy': 'MultiClass'
     }
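A rough way to exercise the new LightGBM mapping (assumes lightgbm is installed; get_equivalent_estimator is a private test helper, not public API):

    import numpy as np
    from sklearn.experimental import enable_hist_gradient_boosting  # noqa
    from sklearn.ensemble import HistGradientBoostingRegressor
    from sklearn.ensemble._hist_gradient_boosting.utils import get_equivalent_estimator
    from sklearn.datasets import make_regression

    X, y = make_regression(n_samples=1000, random_state=0)
    est = HistGradientBoostingRegressor(loss='least_absolute_deviation',
                                        random_state=0)
    lgbm_est = get_equivalent_estimator(est, lib='lightgbm')  # objective 'regression_l1'
    est.fit(X, y)
    lgbm_est.fit(X, y)
    # Predictions should be broadly similar but not identical (see the note on
    # LightGBM's median computation in the comparison tests above).
    print(np.corrcoef(est.predict(X), lgbm_est.predict(X))[0, 1])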
