Skip to content

Commit 8f8c3d2

Browse files
betatimjnothman
authored andcommitted
[MRG] Check for NaN or Inf in metric input (scikit-learn#6976)
1 parent c99aada commit 8f8c3d2

File tree

3 files changed

+32
-0
lines changed

3 files changed

+32
-0
lines changed

sklearn/metrics/classification.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from ..preprocessing import LabelBinarizer, label_binarize
3232
from ..preprocessing import LabelEncoder
33+
from ..utils import assert_all_finite
3334
from ..utils import check_array
3435
from ..utils import check_consistent_length
3536
from ..utils import column_or_1d
@@ -1843,6 +1844,9 @@ def brier_score_loss(y_true, y_prob, sample_weight=None, pos_label=None):
18431844
"""
18441845
y_true = column_or_1d(y_true)
18451846
y_prob = column_or_1d(y_prob)
1847+
assert_all_finite(y_true)
1848+
assert_all_finite(y_prob)
1849+
18461850
if pos_label is None:
18471851
pos_label = y_true.max()
18481852
y_true = np.array(y_true == pos_label, int)

sklearn/metrics/ranking.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import numpy as np
2424
from scipy.sparse import csr_matrix
2525

26+
from ..utils import assert_all_finite
2627
from ..utils import check_consistent_length
2728
from ..utils import column_or_1d, check_array
2829
from ..utils.multiclass import type_of_target
@@ -296,6 +297,9 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
296297
check_consistent_length(y_true, y_score)
297298
y_true = column_or_1d(y_true)
298299
y_score = column_or_1d(y_score)
300+
assert_all_finite(y_true)
301+
assert_all_finite(y_score)
302+
299303
if sample_weight is not None:
300304
sample_weight = column_or_1d(sample_weight)
301305

sklearn/metrics/tests/test_common.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sklearn.utils.testing import assert_greater
2020
from sklearn.utils.testing import assert_not_equal
2121
from sklearn.utils.testing import assert_raises
22+
from sklearn.utils.testing import assert_raise_message
2223
from sklearn.utils.testing import assert_true
2324
from sklearn.utils.testing import ignore_warnings
2425

@@ -608,6 +609,29 @@ def test_invariance_string_vs_numbers_labels():
608609
assert_raises(ValueError, metric, y1_str.astype('O'), y2)
609610

610611

612+
def test_inf_nan_input():
613+
invalids =[([0, 1], [np.inf, np.inf]),
614+
([0, 1], [np.nan, np.nan]),
615+
([0, 1], [np.nan, np.inf])]
616+
617+
METRICS = dict()
618+
METRICS.update(THRESHOLDED_METRICS)
619+
METRICS.update(REGRESSION_METRICS)
620+
621+
for metric in METRICS.values():
622+
for y_true, y_score in invalids:
623+
assert_raise_message(ValueError,
624+
"contains NaN, infinity",
625+
metric, y_true, y_score)
626+
627+
# Classification metrics all raise a mixed input exception
628+
for metric in CLASSIFICATION_METRICS.values():
629+
for y_true, y_score in invalids:
630+
assert_raise_message(ValueError,
631+
"Can't handle mix of binary and continuous",
632+
metric, y_true, y_score)
633+
634+
611635
@ignore_warnings
612636
def check_single_sample(name):
613637
# Non-regression test: scores should work with a single sample.

0 commit comments

Comments
 (0)