Skip to content

Commit 1aa0f07

Browse files
committed
WIP TST adapt test for sample based metrics
1 parent 19bf47e commit 1aa0f07

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

sklearn/metrics/tests/test_metrics.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,7 +1275,8 @@ def test_symmetry():
12751275

12761276
# We shouldn't forget any metrics
12771277
assert_equal(set(SYMMETRIC_METRICS).union(NOT_SYMMETRIC_METRICS,
1278-
THRESHOLDED_METRICS),
1278+
THRESHOLDED_METRICS,
1279+
METRIC_UNDEFINED_MULTICLASS),
12791280
set(ALL_METRICS))
12801281

12811282
assert_equal(
@@ -1307,6 +1308,7 @@ def test_symmetry():
13071308
assert_almost_equal(f(zero_one_score)(y_true, y_pred),
13081309
f(zero_one_score)(y_pred, y_true))
13091310

1311+
13101312
def test_sample_order_invariance():
13111313
y_true, y_pred, _ = make_prediction(binary=True)
13121314
y_true_shuffle, y_pred_shuffle = shuffle(y_true, y_pred, random_state=0)
@@ -1475,32 +1477,37 @@ def test_invariance_string_vs_numbers_labels():
14751477

14761478

14771479
@ignore_warnings
1478-
def check_clf_single_sample(metric):
1480+
def check_single_sample(name):
14791481
"""Non-regression test: scores should work with a single sample.
14801482
14811483
This is important for leave-one-out cross validation.
14821484
Score functions tested are those that formerly called np.squeeze,
14831485
which turns an array of size 1 into a 0-d array (!).
14841486
"""
1487+
metric = ALL_METRICS[name]
1488+
14851489
# assert that no exception is thrown
14861490
for i, j in product([0, 1], repeat=2):
14871491
metric([i], [j])
14881492

1489-
14901493
@ignore_warnings
1491-
def check_clf_single_sample_multioutput(metric):
1494+
def check_single_sample_multioutput(name):
1495+
metric = ALL_METRICS[name]
14921496
for i, j, k, l in product([0, 1], repeat=4):
14931497
metric(np.array([[i, j]]), np.array([[k, l]]))
14941498

14951499

1496-
def test_clf_single_sample():
1497-
for name, metric in (CLASSIFICATION_METRICS.items() +
1498-
REGRESSION_METRICS.items()):
1499-
yield check_clf_single_sample, metric
1500+
def test_single_sample():
1501+
for name in ALL_METRICS:
1502+
if name in METRIC_UNDEFINED_MULTICLASS + THRESHOLDED_METRICS.keys():
1503+
# Those metrics are not always defined with one sample
1504+
# or in multiclass classification
1505+
continue
1506+
1507+
yield check_single_sample, name
15001508

15011509
for name in MULTIOUTPUT_METRICS + MULTILABELS_METRICS:
1502-
metric = ALL_METRICS[name]
1503-
yield check_clf_single_sample_multioutput, metric
1510+
yield check_single_sample_multioutput, name
15041511

15051512

15061513
def test_hinge_loss_binary():

0 commit comments

Comments
 (0)