Skip to content

Commit aed81ed

Browse files
authored
MNT Add more sample weight checks in regression metric common tests (scikit-learn#31726)
1 parent fc95dd2 commit aed81ed

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

sklearn/metrics/tests/test_common.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1614,6 +1614,19 @@ def test_regression_with_invalid_sample_weight(name):
16141614
with pytest.raises(ValueError, match="Found input variables with inconsistent"):
16151615
metric(y_true, y_pred, sample_weight=sample_weight)
16161616

1617+
sample_weight = random_state.random_sample(size=(n_samples,))
1618+
sample_weight[0] = np.inf
1619+
with pytest.raises(ValueError, match="Input sample_weight contains infinity"):
1620+
metric(y_true, y_pred, sample_weight=sample_weight)
1621+
1622+
sample_weight[0] = np.nan
1623+
with pytest.raises(ValueError, match="Input sample_weight contains NaN"):
1624+
metric(y_true, y_pred, sample_weight=sample_weight)
1625+
1626+
sample_weight = np.array([1 + 2j, 3 + 4j, 5 + 7j])
1627+
with pytest.raises(ValueError, match="Complex data not supported"):
1628+
metric(y_true[:3], y_pred[:3], sample_weight=sample_weight)
1629+
16171630
sample_weight = random_state.random_sample(size=(n_samples * 2,)).reshape(
16181631
(n_samples, 2)
16191632
)

0 commit comments

Comments
 (0)