Skip to content

Commit c9e6d4d

Browse files
gxyd authored and jnothman committed
ENH Make GradientBoostingClassifier error message more informative (scikit-learn#10207)
1 parent b933b09 commit c9e6d4d

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

sklearn/ensemble/gradient_boosting.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,7 +1003,7 @@ def fit(self, X, y, sample_weight=None, monitor=None):
10031003

10041004
check_consistent_length(X, y, sample_weight)
10051005

1006-
y = self._validate_y(y)
1006+
y = self._validate_y(y, sample_weight)
10071007

10081008
if self.n_iter_no_change is not None:
10091009
X, X_val, y, y_val, sample_weight, sample_weight_val = (
@@ -1237,7 +1237,9 @@ def feature_importances_(self):
12371237
importances = total_sum / len(self.estimators_)
12381238
return importances
12391239

1240-
def _validate_y(self, y):
1240+
def _validate_y(self, y, sample_weight):
1241+
# 'sample_weight' is not utilised but is used for
1242+
# consistency with similar method _validate_y of GBC
12411243
self.n_classes_ = 1
12421244
if y.dtype.kind == 'O':
12431245
y = y.astype(np.float64)
@@ -1551,9 +1553,15 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100,
15511553
validation_fraction=validation_fraction,
15521554
n_iter_no_change=n_iter_no_change, tol=tol)
15531555

1554-
def _validate_y(self, y):
1556+
def _validate_y(self, y, sample_weight):
15551557
check_classification_targets(y)
15561558
self.classes_, y = np.unique(y, return_inverse=True)
1559+
n_trim_classes = np.count_nonzero(np.bincount(y, sample_weight))
1560+
if n_trim_classes < 2:
1561+
raise ValueError("y contains %d class after sample_weight "
1562+
"trimmed classes with zero weights, while a "
1563+
"minimum of 2 classes are required."
1564+
% n_trim_classes)
15571565
self.n_classes_ = len(self.classes_)
15581566
return y
15591567

sklearn/ensemble/tests/test_gradient_boosting.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
from sklearn.utils.testing import assert_equal
2727
from sklearn.utils.testing import assert_greater
2828
from sklearn.utils.testing import assert_less
29-
from sklearn.utils.testing import assert_raise_message
3029
from sklearn.utils.testing import assert_raises
30+
from sklearn.utils.testing import assert_raise_message
3131
from sklearn.utils.testing import assert_true
3232
from sklearn.utils.testing import assert_warns
3333
from sklearn.utils.testing import assert_warns_message
@@ -364,6 +364,12 @@ def test_check_inputs():
364364
assert_raises(ValueError, clf.fit, X, y,
365365
sample_weight=([1] * len(y)) + [0, 1])
366366

367+
weight = [0, 0, 0, 1, 1, 1]
368+
clf = GradientBoostingClassifier(n_estimators=100, random_state=1)
369+
msg = ("y contains 1 class after sample_weight trimmed classes with "
370+
"zero weights, while a minimum of 2 classes are required.")
371+
assert_raise_message(ValueError, msg, clf.fit, X, y, sample_weight=weight)
372+
367373

368374
def test_check_inputs_predict():
369375
# X has wrong shape

0 commit comments

Comments
 (0)