Skip to content

Commit fd3a816

Browse files
committed
renamed supported_loss to _SUPPORTED_LOSS (constants)
add test for deprecated warning
1 parent b952cdb commit fd3a816

File tree

2 files changed

+25
-19
lines changed

2 files changed

+25
-19
lines changed

sklearn/ensemble/gradient_boosting.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -498,10 +498,14 @@ def _check_params(self):
498498
if self.learning_rate <= 0.0:
499499
raise ValueError("learning_rate must be greater than 0")
500500

501-
if (self.loss not in self.supported_loss or
501+
if (self.loss not in self._SUPPORTED_LOSS or
502502
self.loss not in LOSS_FUNCTIONS):
503503
raise ValueError("Loss '{0:s}' not supported. ".format(self.loss))
504504

505+
if self.loss in ('mdeviance', 'bdeviance'):
506+
warn(("Loss '{0:s}' is deprecated as of version 0.14. "
507+
"Use 'deviance' instead. ").format(self.loss))
508+
505509
if self.loss == 'deviance':
506510
loss_class = (MultinomialDeviance
507511
if len(self.classes_) > 2
@@ -854,7 +858,7 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):
854858
Elements of Statistical Learning Ed. 2, Springer, 2009.
855859
"""
856860

857-
supported_loss = ('deviance', 'mdeviance', 'bdeviance')
861+
_SUPPORTED_LOSS = ('deviance', 'mdeviance', 'bdeviance')
858862

859863
def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100,
860864
subsample=1.0, min_samples_split=2, min_samples_leaf=1,
@@ -1095,7 +1099,7 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
10951099
Elements of Statistical Learning Ed. 2, Springer, 2009.
10961100
"""
10971101

1098-
supported_loss = ('ls', 'lad', 'huber', 'quantile')
1102+
_SUPPORTED_LOSS = ('ls', 'lad', 'huber', 'quantile')
10991103

11001104
def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100,
11011105
subsample=1.0, min_samples_split=2, min_samples_leaf=1,

sklearn/ensemble/tests/test_gradient_boosting.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -103,18 +103,6 @@ def test_parameter_checks():
103103
assert_raises(ValueError,
104104
lambda: GradientBoostingClassifier().feature_importances_)
105105

106-
# binomial deviance requires ``n_classes == 2``.
107-
assert_raises(ValueError,
108-
lambda X, y: GradientBoostingClassifier(
109-
loss='bdeviance').fit(X, y),
110-
X, [0, 0, 1, 1, 2, 2])
111-
112-
# multinomial deviance requires ``n_classes > 2``.
113-
assert_raises(ValueError,
114-
lambda X, y: GradientBoostingClassifier(
115-
loss='mdeviance').fit(X, y),
116-
X, [0, 0, 1, 1, 1, 0])
117-
118106
# deviance requires ``n_classes >= 2``.
119107
assert_raises(ValueError,
120108
lambda X, y: GradientBoostingClassifier(
@@ -133,10 +121,6 @@ def test_loss_function():
133121
GradientBoostingClassifier(loss='huber').fit, X, y)
134122
assert_raises(ValueError,
135123
GradientBoostingRegressor(loss='deviance').fit, X, y)
136-
assert_raises(ValueError,
137-
GradientBoostingRegressor(loss='bdeviance').fit, X, y)
138-
assert_raises(ValueError,
139-
GradientBoostingRegressor(loss='mdeviance').fit, X, y)
140124

141125

142126
def test_classification_synthetic():
@@ -596,3 +580,21 @@ def test_more_verbose_output():
596580
n_lines = sum(1 for l in verbose_output.readlines())
597581
# 100 lines for n_estimators==100
598582
assert_equal(100, n_lines)
583+
584+
585+
def test_warn_deviance():
586+
"""Test that mdeviance and bdeviance give a deprecation warning. """
587+
for loss in ('bdeviance', 'mdeviance'):
588+
with warnings.catch_warnings(record=True) as w:
589+
# This will raise a DataConversionWarning that we want to
590+
# "always" raise, otherwise the warnings get ignored in the
591+
# later tests, and the tests that check for this warning fail
592+
warnings.simplefilter("always", DataConversionWarning)
593+
clf = GradientBoostingClassifier(loss=loss)
594+
try:
595+
clf.fit(X, y)
596+
except:
597+
# mdeviance will raise ValueError because only 2 classes
598+
pass
599+
# deprecation warning for bdeviance and mdeviance
600+
assert len(w) == 1

0 commit comments

Comments
 (0)