Skip to content

Commit 17e8f01

Browse files
committed
ENH / TST better coverage of supervised clustering metrics, slight cleanup
1 parent 270bc8e commit 17e8f01

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

sklearn/metrics/cluster/supervised.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -73,13 +73,13 @@ def contingency_matrix(labels_true, labels_pred, eps=None):
7373
# Using coo_matrix to accelerate simple histogram calculation,
7474
# i.e. bins are consecutive integers
7575
# Currently, coo_matrix is faster than histogram2d for simple cases
76-
contingency = np.asarray(coo_matrix((np.ones(class_idx.shape[0]),
77-
(class_idx, cluster_idx)),
78-
shape=(n_classes, n_clusters),
79-
dtype=np.int).todense())
76+
contingency = coo_matrix((np.ones(class_idx.shape[0]),
77+
(class_idx, cluster_idx)),
78+
shape=(n_classes, n_clusters),
79+
dtype=np.int).toarray()
8080
if eps is not None:
81-
# Must be a float matrix to accept float eps
82-
contingency = np.array(contingency, dtype='float') + eps
81+
# don't use += as contingency is integer
82+
contingency = contingency + eps
8383
return contingency
8484

8585

sklearn/metrics/cluster/tests/test_supervised.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -168,6 +168,19 @@ def test_adjusted_mutual_info_score():
168168
def test_entropy():
169169
ent = entropy([0, 0, 42.])
170170
assert_almost_equal(ent, 0.6365141, 5)
171+
assert_almost_equal(entropy([]), 1)
172+
173+
174+
def test_contingency_matrix():
175+
labels_a = np.array([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])
176+
labels_b = np.array([1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 3, 1, 3, 3, 3, 2, 2])
177+
C = contingency_matrix(labels_a, labels_b)
178+
C2 = np.histogram2d(labels_a, labels_b,
179+
bins=(np.arange(1, 5),
180+
np.arange(1, 5)))[0]
181+
assert_array_almost_equal(C, C2)
182+
C = contingency_matrix(labels_a, labels_b, eps=.1)
183+
assert_array_almost_equal(C, C2 + .1)
171184

172185

173186
def test_exactly_zero_info_score():

0 commit comments

Comments
 (0)