File tree Expand file tree Collapse file tree 4 files changed +41
-2
lines changed
doc/whats_new/upcoming_changes/sklearn.utils Expand file tree Collapse file tree 4 files changed +41
-2
lines changed Original file line number Diff line number Diff line change 1+ - Fixed a spurious warning (about the number of unique classes being
2+ greater than 50% of the number of samples) that could occur when
3+ passing `classes ` :func: `utils.multiclass.type_of_target `.
4+ By :user: `Sascha D. Krauss <saskra> `.
Original file line number Diff line number Diff line change @@ -414,7 +414,7 @@ def _raise_or_return():
414414 if issparse (first_row_or_val ):
415415 first_row_or_val = first_row_or_val .data
416416 classes = cached_unique (y )
417- if y .shape [0 ] > 20 and classes .shape [0 ] > round (0.5 * y .shape [0 ]):
417+ if y .shape [0 ] > 20 and y . shape [ 0 ] > classes .shape [0 ] > round (0.5 * y .shape [0 ]):
418418 # Only raise the warning when we have at least 20 samples.
419419 warnings .warn (
420420 "The number of unique classes is greater than 50% of the number "
Original file line number Diff line number Diff line change @@ -302,7 +302,11 @@ def test_type_of_target_too_many_unique_classes():
302302 We need to check that we don't raise if we have less than 20 samples.
303303 """
304304
305- y = np .arange (25 )
305+ # Create array of unique labels, except '0', which appears twice.
306+ # This does raise a warning.
307+ # Note warning would not be raised if we passed only unique
308+ # labels, which happens when `type_of_target` is passed `classes_`.
309+ y = np .hstack ((np .arange (20 ), [0 ]))
306310 msg = r"The number of unique classes is greater than 50% of the number of samples."
307311 with pytest .warns (UserWarning , match = msg ):
308312 type_of_target (y )
@@ -313,6 +317,14 @@ def test_type_of_target_too_many_unique_classes():
313317 warnings .simplefilter ("error" )
314318 type_of_target (y )
315319
320+ # More than 20 samples but only unique classes, simulating passing
321+ # `classes_` to `type_of_target` (when number of classes is large).
322+ # No warning should be raised
323+ y = np .arange (25 )
324+ with warnings .catch_warnings ():
325+ warnings .simplefilter ("ignore" , UserWarning )
326+ type_of_target (y )
327+
316328
317329def test_unique_labels_non_specific ():
318330 # Test unique_labels with a variety of collected examples
Original file line number Diff line number Diff line change 1+ import warnings
2+
13import numpy as np
24import pytest
35
@@ -369,3 +371,24 @@ def test_get_response_values_multilabel_indicator(response_method):
369371 assert (y_pred > 1 ).sum () > 0
370372 else : # response_method == "predict"
371373 assert np .logical_or (y_pred == 0 , y_pred == 1 ).all ()
374+
375+
376+ def test_response_values_type_of_target_on_classes_no_warning ():
377+ """
378+ Ensure `_get_response_values` doesn't raise spurious warning.
379+
380+ "The number of unique classes is greater than > 50% of samples"
381+ warning should not be raised when calling `type_of_target(classes_)`.
382+
383+ Non-regression test for issue #31583.
384+ """
385+ X = np .random .RandomState (0 ).randn (120 , 3 )
386+ # 30 classes, less than 50% of number of samples
387+ y = np .repeat (np .arange (30 ), 4 )
388+
389+ clf = LogisticRegression ().fit (X , y )
390+
391+ with warnings .catch_warnings ():
392+ warnings .simplefilter ("error" , UserWarning )
393+
394+ _get_response_values (clf , X , response_method = "predict_proba" )
You can’t perform that action at this time.
0 commit comments