Skip to content

Commit 9b7a86f

Browse files
saskrajeremiedbblucyleeow
authored
Fix spurious warning from type_of_target when called on estimator.classes_ (scikit-learn#31584)
Co-authored-by: Jérémie du Boisberranger <[email protected]> Co-authored-by: Lucy Liu <[email protected]>
1 parent f187311 commit 9b7a86f

File tree

4 files changed

+41
-2
lines changed

4 files changed

+41
-2
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- Fixed a spurious warning (about the number of unique classes being
2+
greater than 50% of the number of samples) that could occur when
3+
passing an estimator's `classes_` to :func:`utils.multiclass.type_of_target`.
4+
By :user:`Sascha D. Krauss <saskra>`.

sklearn/utils/multiclass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ def _raise_or_return():
414414
if issparse(first_row_or_val):
415415
first_row_or_val = first_row_or_val.data
416416
classes = cached_unique(y)
417-
if y.shape[0] > 20 and classes.shape[0] > round(0.5 * y.shape[0]):
417+
if y.shape[0] > 20 and y.shape[0] > classes.shape[0] > round(0.5 * y.shape[0]):
418418
# Only raise the warning when we have more than 20 samples.
419419
warnings.warn(
420420
"The number of unique classes is greater than 50% of the number "

sklearn/utils/tests/test_multiclass.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,11 @@ def test_type_of_target_too_many_unique_classes():
302302
We need to check that we don't raise if we have less than 20 samples.
303303
"""
304304

305-
y = np.arange(25)
305+
# Create array of unique labels, except '0', which appears twice.
306+
# This does raise a warning.
307+
# Note warning would not be raised if we passed only unique
308+
# labels, which happens when `type_of_target` is passed `classes_`.
309+
y = np.hstack((np.arange(20), [0]))
306310
msg = r"The number of unique classes is greater than 50% of the number of samples."
307311
with pytest.warns(UserWarning, match=msg):
308312
type_of_target(y)
@@ -313,6 +317,14 @@ def test_type_of_target_too_many_unique_classes():
313317
warnings.simplefilter("error")
314318
type_of_target(y)
315319

320+
# More than 20 samples but only unique classes, simulating passing
321+
# `classes_` to `type_of_target` (when number of classes is large).
322+
# No warning should be raised
323+
y = np.arange(25)
324+
with warnings.catch_warnings():
325+
warnings.simplefilter("error", UserWarning)
326+
type_of_target(y)
327+
316328

317329
def test_unique_labels_non_specific():
318330
# Test unique_labels with a variety of collected examples

sklearn/utils/tests/test_response.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
import numpy as np
24
import pytest
35

@@ -369,3 +371,24 @@ def test_get_response_values_multilabel_indicator(response_method):
369371
assert (y_pred > 1).sum() > 0
370372
else: # response_method == "predict"
371373
assert np.logical_or(y_pred == 0, y_pred == 1).all()
374+
375+
376+
def test_response_values_type_of_target_on_classes_no_warning():
377+
"""
378+
Ensure `_get_response_values` doesn't raise a spurious warning.
379+
380+
"The number of unique classes is greater than 50% of samples"
381+
warning should not be raised when calling `type_of_target(classes_)`.
382+
383+
Non-regression test for issue #31583.
384+
"""
385+
X = np.random.RandomState(0).randn(120, 3)
386+
# 30 classes, less than 50% of number of samples
387+
y = np.repeat(np.arange(30), 4)
388+
389+
clf = LogisticRegression().fit(X, y)
390+
391+
with warnings.catch_warnings():
392+
warnings.simplefilter("error", UserWarning)
393+
394+
_get_response_values(clf, X, response_method="predict_proba")

0 commit comments

Comments
 (0)