Skip to content

Commit 1b9ad77

Browse files
committed
TST r2_score float32 overflow fix
1 parent e433d20 commit 1b9ad77

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

sklearn/metrics/tests/test_metrics.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,17 @@ def test_r2_one_case_error():
864864
assert_raises(ValueError, r2_score, [0], [0])
865865

866866

867+
def test_r2_overflow():
868+
"""r2_score should not overflow on large arrays of dtype=float32."""
869+
# Simulate a large array by a small one with extreme values,
870+
# but not extreme enough to overflow in the squared difference step.
871+
y_32 = np.repeat(np.sqrt(np.finfo(np.float32).max), 10)
872+
y_64 = y_32.astype(np.float64)
873+
z = np.zeros(10, dtype=np.float32)
874+
875+
assert_equal(r2_score(y_32, z), r2_score(y_64, z))
876+
877+
867878
def test_symmetry():
868879
"""Test the symmetry of score and loss functions"""
869880
y_true, y_pred, _ = make_prediction(binary=True)

0 commit comments

Comments
 (0)