Skip to content

Commit 24653e8

Browse files
lesteve authored and MechCoder committed
[MRG] Fix sklearn.metrics.classification._check_targets where y_true and y_pred are both binary but the union is multiclass (scikit-learn#8377)
* Fix _check_targets where y_true and y_pred are both binary but the union of them is multiclass.
* Add entry in changelog
1 parent 195de6a commit 24653e8

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

doc/whats_new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,10 @@ Bug fixes
260260
obstructed pickling customizations of child-classes, when used in a
261261
multiple inheritance context.
262262
:issue:`8316` by :user:`Holger Peters <HolgerPeters>`.
263+
- Fix a bug in :func:`sklearn.metrics.classification._check_targets`
264+
which would return ``'binary'`` if ``y_true`` and ``y_pred`` were
265+
both ``'binary'`` but the union of ``y_true`` and ``y_pred`` was
266+
``'multiclass'``. :issue:`8377` by `Loic Esteve`_.
263267

264268
- Fix :func:`sklearn.linear_model.BayesianRidge.fit` to return
265269
ridge parameter `alpha_` and `lambda_` consistent with calculated

sklearn/metrics/classification.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ def _check_targets(y_true, y_pred):
9191
if y_type in ["binary", "multiclass"]:
9292
y_true = column_or_1d(y_true)
9393
y_pred = column_or_1d(y_pred)
94+
if y_type == "binary":
95+
unique_values = np.union1d(y_true, y_pred)
96+
if len(unique_values) > 2:
97+
y_type = "multiclass"
9498

9599
if y_type.startswith('multilabel'):
96100
y_true = csr_matrix(y_true)

sklearn/metrics/tests/test_classification.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,8 @@ def test_matthews_corrcoef():
366366
y_true_inv = ["b" if i == "a" else "a" for i in y_true]
367367

368368
assert_almost_equal(matthews_corrcoef(y_true, y_true_inv), -1)
369-
y_true_inv2 = label_binarize(y_true, ["a", "b"]) * -1
369+
y_true_inv2 = label_binarize(y_true, ["a", "b"])
370+
y_true_inv2 = np.where(y_true_inv2, 'a', 'b')
370371
assert_almost_equal(matthews_corrcoef(y_true, y_true_inv2), -1)
371372

372373
# For the zero vector case, the corrcoef cannot be calculated and should
@@ -379,8 +380,7 @@ def test_matthews_corrcoef():
379380

380381
# And also for any other vector with 0 variance
381382
mcc = assert_warns_message(RuntimeWarning, 'invalid value encountered',
382-
matthews_corrcoef, y_true,
383-
rng.randint(-100, 100) * np.ones(20, dtype=int))
383+
matthews_corrcoef, y_true, ['a'] * len(y_true))
384384

385385
# But will output 0
386386
assert_almost_equal(mcc, 0.)
@@ -1267,6 +1267,13 @@ def test__check_targets():
12671267
assert_raise_message(ValueError, msg, _check_targets, y1, y2)
12681268

12691269

1270+
def test__check_targets_multiclass_with_both_y_true_and_y_pred_binary():
1271+
# https://github.com/scikit-learn/scikit-learn/issues/8098
1272+
y_true = [0, 1]
1273+
y_pred = [0, -1]
1274+
assert_equal(_check_targets(y_true, y_pred)[0], 'multiclass')
1275+
1276+
12701277
def test_hinge_loss_binary():
12711278
y_true = np.array([-1, 1, 1, -1])
12721279
pred_decision = np.array([-8.5, 0.5, 1.5, -0.3])

0 commit comments

Comments
 (0)