@@ -1003,7 +1003,7 @@ def fit(self, X, y, sample_weight=None, monitor=None):
10031003
10041004 check_consistent_length (X , y , sample_weight )
10051005
1006- y = self ._validate_y (y )
1006+ y = self ._validate_y (y , sample_weight )
10071007
10081008 if self .n_iter_no_change is not None :
10091009 X , X_val , y , y_val , sample_weight , sample_weight_val = (
@@ -1237,7 +1237,9 @@ def feature_importances_(self):
12371237 importances = total_sum / len (self .estimators_ )
12381238 return importances
12391239
1240- def _validate_y (self , y ):
1240+ def _validate_y (self , y , sample_weight ):
1241+ # 'sample_weight' is not utilised but is used for
1242+ # consistency with similar method _validate_y of GBC
12411243 self .n_classes_ = 1
12421244 if y .dtype .kind == 'O' :
12431245 y = y .astype (np .float64 )
@@ -1551,9 +1553,15 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100,
15511553 validation_fraction = validation_fraction ,
15521554 n_iter_no_change = n_iter_no_change , tol = tol )
15531555
1554- def _validate_y (self , y ):
1556+ def _validate_y (self , y , sample_weight ):
15551557 check_classification_targets (y )
15561558 self .classes_ , y = np .unique (y , return_inverse = True )
1559+ n_trim_classes = np .count_nonzero (np .bincount (y , sample_weight ))
1560+ if n_trim_classes < 2 :
1561+ raise ValueError ("y contains %d class after sample_weight "
1562+ "trimmed classes with zero weights, while a "
1563+ "minimum of 2 classes are required."
1564+ % n_trim_classes )
15571565 self .n_classes_ = len (self .classes_ )
15581566 return y
15591567
0 commit comments