
Commit 88e096b

jnothman authored and rth committed
[MRG] TST add test for silhouette from its original publication (scikit-learn#10298)
1 parent 026395f commit 88e096b

File tree

1 file changed: +47 −0 lines changed


sklearn/metrics/cluster/tests/test_unsupervised.py

Lines changed: 47 additions & 0 deletions
@@ -81,6 +81,53 @@ def test_cluster_size_1():
     assert_array_equal(ss, [0, .5, .5, 0, 1, 1])


+def test_silhouette_paper_example():
+    # Explicitly check per-sample results against Rousseeuw (1987)
+    # Data from Table 1
+    lower = [5.58,
+             7.00, 6.50,
+             7.08, 7.00, 3.83,
+             4.83, 5.08, 8.17, 5.83,
+             2.17, 5.75, 6.67, 6.92, 4.92,
+             6.42, 5.00, 5.58, 6.00, 4.67, 6.42,
+             3.42, 5.50, 6.42, 6.42, 5.00, 3.92, 6.17,
+             2.50, 4.92, 6.25, 7.33, 4.50, 2.25, 6.33, 2.75,
+             6.08, 6.67, 4.25, 2.67, 6.00, 6.17, 6.17, 6.92, 6.17,
+             5.25, 6.83, 4.50, 3.75, 5.75, 5.42, 6.08, 5.83, 6.67, 3.67,
+             4.75, 3.00, 6.08, 6.67, 5.00, 5.58, 4.83, 6.17, 5.67, 6.50, 6.92]
+    D = np.zeros((12, 12))
+    D[np.tril_indices(12, -1)] = lower
+    D += D.T
+
+    names = ['BEL', 'BRA', 'CHI', 'CUB', 'EGY', 'FRA', 'IND', 'ISR', 'USA',
+             'USS', 'YUG', 'ZAI']
+
+    # Data from Figure 2
+    labels1 = [1, 1, 2, 2, 1, 1, 2, 1, 1, 2, 2, 1]
+    expected1 = {'USA': .43, 'BEL': .39, 'FRA': .35, 'ISR': .30, 'BRA': .22,
+                 'EGY': .20, 'ZAI': .19, 'CUB': .40, 'USS': .34, 'CHI': .33,
+                 'YUG': .26, 'IND': -.04}
+    score1 = .28
+
+    # Data from Figure 3
+    labels2 = [1, 2, 3, 3, 1, 1, 2, 1, 1, 3, 3, 2]
+    expected2 = {'USA': .47, 'FRA': .44, 'BEL': .42, 'ISR': .37, 'EGY': .02,
+                 'ZAI': .28, 'BRA': .25, 'IND': .17, 'CUB': .48, 'USS': .44,
+                 'YUG': .31, 'CHI': .31}
+    score2 = .33
+
+    for labels, expected, score in [(labels1, expected1, score1),
+                                    (labels2, expected2, score2)]:
+        expected = [expected[name] for name in names]
+        # we check to 2dp because that's what's in the paper
+        assert_almost_equal(expected, silhouette_samples(D, np.array(labels),
+                                                         metric='precomputed'),
+                            decimal=2)
+        assert_almost_equal(score, silhouette_score(D, np.array(labels),
+                                                    metric='precomputed'),
+                            decimal=2)
+
+
 def test_correct_labelsize():
     # Assert 1 < n_labels < n_samples
     dataset = datasets.load_iris()
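For reference, the quantity this test checks is the per-sample silhouette value s(i) = (b(i) - a(i)) / max(a(i), b(i)), where a(i) is the mean distance from sample i to the other members of its own cluster and b(i) is the smallest mean distance from i to the members of any other cluster. The sketch below is not part of the commit; the helper name `reference_silhouette` is made up for illustration, and it assumes every cluster has at least two members (true for both labellings above). It shows how the paper's values could be recomputed directly from the precomputed distance matrix D built in the test.

```python
import numpy as np


def reference_silhouette(D, labels):
    """Naive per-sample silhouette values from a precomputed distance matrix.

    Hypothetical helper for illustration only; assumes every cluster has
    at least two members.
    """
    labels = np.asarray(labels)
    s = np.empty(D.shape[0])
    for i in range(D.shape[0]):
        same = labels == labels[i]
        same[i] = False                      # exclude the sample itself
        a = D[i, same].mean()                # mean intra-cluster distance
        b = min(D[i, labels == k].mean()     # mean distance to the nearest other cluster
                for k in np.unique(labels) if k != labels[i])
        s[i] = (b - a) / max(a, b)
    return s
```

Rounding `reference_silhouette(D, labels1)` to two decimals should reproduce the Figure 2 values listed in the test, and agree with `silhouette_samples(D, np.array(labels1), metric='precomputed')`.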

0 commit comments

Comments
 (0)