Skip to content

Commit 622f912

Browse files
aliddell authored and qinhanmin2014 committed
FIX Solves integer overflow in fowlkes_mallows_score (scikit-learn#10844)
1 parent 67cc975 commit 622f912

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

doc/whats_new/v0.20.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,12 @@ Metrics
421421
:func:`mutual_info_score`.
422422
:issue:`9772` by :user:`Kumar Ashutosh <thechargedneutron>`.
423423

424+
- Fixed a bug in :func:`metrics.cluster.fowlkes_mallows_score` to avoid integer
425+
overflow. Cast return value of `contingency_matrix` to `int64` and computed
426+
product of square roots rather than square root of product.
427+
:issue:`9515` by :user:`Alan Liddell <aliddell>` and
428+
:user:`Manh Dao <manhdao>`.
429+
424430
Neighbors
425431

426432
- Fixed a bug so ``predict`` in :class:`neighbors.RadiusNeighborsRegressor` can

sklearn/metrics/cluster/supervised.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -852,11 +852,12 @@ def fowlkes_mallows_score(labels_true, labels_pred, sparse=False):
852852
labels_true, labels_pred = check_clusterings(labels_true, labels_pred)
853853
n_samples, = labels_true.shape
854854

855-
c = contingency_matrix(labels_true, labels_pred, sparse=True)
855+
c = contingency_matrix(labels_true, labels_pred,
856+
sparse=True).astype(np.int64)
856857
tk = np.dot(c.data, c.data) - n_samples
857858
pk = np.sum(np.asarray(c.sum(axis=0)).ravel() ** 2) - n_samples
858859
qk = np.sum(np.asarray(c.sum(axis=1)).ravel() ** 2) - n_samples
859-
return tk / np.sqrt(pk * qk) if tk != 0. else 0.
860+
return np.sqrt(tk / pk) * np.sqrt(tk / qk) if tk != 0. else 0.
860861

861862

862863
def entropy(labels):

sklearn/metrics/cluster/tests/test_supervised.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,15 +173,16 @@ def test_expected_mutual_info_overflow():
173173
assert expected_mutual_information(np.array([[70000]]), 70000) <= 1
174174

175175

176-
def test_int_overflow_mutual_info_score():
177-
# Test overflow in mutual_info_classif
176+
def test_int_overflow_mutual_info_fowlkes_mallows_score():
177+
# Test overflow in mutual_info_classif and fowlkes_mallows_score
178178
x = np.array([1] * (52632 + 2529) + [2] * (14660 + 793) + [3] * (3271 +
179179
204) + [4] * (814 + 39) + [5] * (316 + 20))
180180
y = np.array([0] * 52632 + [1] * 2529 + [0] * 14660 + [1] * 793 +
181181
[0] * 3271 + [1] * 204 + [0] * 814 + [1] * 39 + [0] * 316 +
182182
[1] * 20)
183183

184-
assert_all_finite(mutual_info_score(x.ravel(), y.ravel()))
184+
assert_all_finite(mutual_info_score(x, y))
185+
assert_all_finite(fowlkes_mallows_score(x, y))
185186

186187

187188
def test_entropy():

0 commit comments

Comments
 (0)