Commit 2c984ee

TST better test-coverage in clustering module
1 parent: 3de4442 · commit: 2c984ee

8 files changed, +164 -83 lines changed


sklearn/cluster/affinity_propagation_.py

Lines changed: 12 additions & 13 deletions
@@ -16,8 +16,8 @@
 
 
 def affinity_propagation(S, preference=None, p=None, convergence_iter=15,
-                         convit=None, max_iter=200,
-                         damping=0.5, copy=True, verbose=False):
+                         convit=None, max_iter=200, damping=0.5, copy=True,
+                         verbose=False):
     """Perform Affinity Propagation Clustering of data
 
     Parameters
@@ -80,16 +80,15 @@ def affinity_propagation(S, preference=None, p=None, convergence_iter=15,
     n_samples = S.shape[0]
 
     if S.shape[0] != S.shape[1]:
-        raise ValueError("S must be a square array (shape=%r)" % S.shape)
+        raise ValueError("S must be a square array (shape=%s)" % repr(S.shape))
 
     if not p is None:
         warnings.warn("p is deprecated and will be removed in version 0.14."
-                      " Use ``preference`` instead.", DeprecationWarning)
+                      "Use ``preference`` instead.", DeprecationWarning)
         preference = p
 
     if preference is None:
         preference = np.median(S)
-
     if damping < 0.5 or damping >= 1:
         raise ValueError('damping must be >= 0.5 and < 1')
 
@@ -102,8 +101,8 @@ def affinity_propagation(S, preference=None, p=None, convergence_iter=15,
     R = np.zeros((n_samples, n_samples)) # Initialize messages
 
     # Remove degeneracies
-    S += (np.finfo(np.double).eps * S + np.finfo(np.double).tiny * 100) * \
-        random_state.randn(n_samples, n_samples)
+    S += ((np.finfo(np.double).eps * S + np.finfo(np.double).tiny * 100) *
+          random_state.randn(n_samples, n_samples))
 
     # Execute parallel affinity propagation updates
     e = np.zeros((n_samples, convergence_iter))
@@ -148,8 +147,8 @@ def affinity_propagation(S, preference=None, p=None, convergence_iter=15,
 
         if it >= convergence_iter:
             se = np.sum(e, axis=1)
-            unconverged = np.sum((se == convergence_iter) +\
-                (se == 0)) != n_samples
+            unconverged = (np.sum((se == convergence_iter) + (se == 0))
+                           != n_samples)
             if (not unconverged and (K > 0)) or (it == max_iter):
                 if verbose:
                     print "Converged after %d iterations." % it
@@ -246,8 +245,8 @@ class AffinityPropagation(BaseEstimator, ClusterMixin):
     """
 
     def __init__(self, damping=.5, max_iter=200, convergence_iter=15,
-                 convit=None, copy=True,
-                 preference=None, p=None, affinity='euclidean', verbose=False):
+                 convit=None, copy=True, preference=None, p=None,
+                 affinity='euclidean', verbose=False):
 
         if convit is not None:
             warnings.warn("``convit`` is deprectaed and will be removed in "
@@ -262,7 +261,7 @@ def __init__(self, damping=.5, max_iter=200, convergence_iter=15,
         self.verbose = verbose
         if not p is None:
             warnings.warn("p is deprecated and will be removed in version 0.14"
-                      ". Use ``preference`` instead.", DeprecationWarning)
+                          ". Use ``preference`` instead.", DeprecationWarning)
             preference = p
 
         self.preference = preference
@@ -295,7 +294,7 @@ def fit(self, X):
            self.affinity_matrix_ = -euclidean_distances(X, squared=True)
         else:
             raise ValueError("Affinity must be 'precomputed' or "
-                         "'euclidean'. Got %s instead" % str(self.affinity))
+                             "'euclidean'. Got %s instead" % str(self.affinity))
 
         self.cluster_centers_indices_, self.labels_ = affinity_propagation(
             self.affinity_matrix_, self.preference,
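
The validation paths touched above are exactly what the new tests exercise: affinity_propagation rejects a non-square similarity matrix and a damping value outside [0.5, 1). A minimal sketch of both the normal call and the error cases, using made-up toy data rather than anything from the commit:

    import numpy as np
    from sklearn.cluster import affinity_propagation

    # Illustrative similarity matrix: negative squared Euclidean distances.
    X = np.array([[0.0, 0.0], [0.5, 0.5], [5.0, 5.0], [5.5, 5.5]])
    S = -((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)

    # Normal call: exemplar indices and one label per sample.
    centers, labels = affinity_propagation(S, preference=np.median(S))

    # These mirror the new assertions in test_affinity_propagation.py;
    # both raise ValueError:
    #   affinity_propagation(S[:, :-1])      # S is not square
    #   affinity_propagation(S, damping=0)   # damping must be in [0.5, 1)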

sklearn/cluster/hierarchical.py

Lines changed: 1 addition & 1 deletion
@@ -347,7 +347,7 @@ def fit(self, X):
         memory = self.memory
         X = array2d(X)
         if isinstance(memory, basestring):
-            memory = Memory(cachedir=memory)
+            memory = Memory(cachedir=memory, verbose=0)
 
         if not self.connectivity is None:
             if not sparse.issparse(self.connectivity):
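
The one-line change keeps the existing behaviour (a string passed as memory is treated as a joblib cache directory) and adds verbose=0 so cached runs stay quiet. A rough sketch of the pattern, assuming the joblib bundled as sklearn.externals.joblib at the time and a hypothetical helper standing in for the real tree computation:

    from tempfile import mkdtemp
    from sklearn.externals.joblib import Memory

    # A cache-directory string becomes a quiet Memory object, mirroring
    # what Ward.fit now does internally with its ``memory`` argument.
    memory = Memory(cachedir=mkdtemp(), verbose=0)

    def expensive_step(n):
        # Hypothetical stand-in for the cached ward tree construction.
        return sum(i * i for i in range(n))

    cached_step = memory.cache(expensive_step)
    cached_step(10000)   # computed and written to the cache directory
    cached_step(10000)   # read back from the cache, with no log output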

sklearn/cluster/k_means_.py

Lines changed: 14 additions & 14 deletions
@@ -155,22 +155,22 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances=True,
 
     Parameters
     ----------
-    X: array-like or sparse matrix, shape (n_samples, n_features)
+    X : array-like or sparse matrix, shape (n_samples, n_features)
         The observations to cluster.
 
-    n_clusters: int
+    n_clusters : int
         The number of clusters to form as well as the number of
         centroids to generate.
 
-    max_iter: int, optional, default 300
+    max_iter : int, optional, default 300
         Maximum number of iterations of the k-means algorithm to run.
 
-    n_init: int, optional, default: 10
+    n_init : int, optional, default: 10
         Number of time the k-means algorithm will be run with different
         centroid seeds. The final results will be the best output of
         n_init consecutive runs in terms of inertia.
 
-    init: {'k-means++', 'random', or ndarray, or a callable}, optional
+    init : {'k-means++', 'random', or ndarray, or a callable}, optional
         Method for initialization, default to 'k-means++':
 
         'k-means++' : selects initial cluster centers for k-mean
@@ -186,25 +186,25 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances=True,
         If a callable is passed, it should take arguments X, k and
         and a random state and return an initialization.
 
-    tol: float, optional
+    tol : float, optional
         The relative increment in the results before declaring convergence.
 
-    verbose: boolean, optional
-        Verbosity mode
+    verbose : boolean, optional
+        Verbosity mode.
 
-    random_state: integer or numpy.RandomState, optional
+    random_state : integer or numpy.RandomState, optional
         The generator used to initialize the centers. If an integer is
         given, it fixes the seed. Defaults to the global numpy random
         number generator.
 
-    copy_x: boolean, optional
+    copy_x : boolean, optional
         When pre-computing distances it is more numerically accurate to center
         the data first. If copy_x is True, then the original data is not
         modified. If False, the original data is modified, and put back before
         the function returns, but small numerical differences may be introduced
         by subtracting and then adding the data mean.
 
-    n_jobs: int
+    n_jobs : int
         The number of jobs to use for the computation. This works by breaking
         down the pairwise matrix into n_jobs even slices and computing them in
         parallel.
@@ -216,14 +216,14 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances=True,
 
     Returns
     -------
-    centroid: float ndarray with shape (k, n_features)
+    centroid : float ndarray with shape (k, n_features)
         Centroids found at the last iteration of k-means.
 
-    label: integer ndarray with shape (n_samples,)
+    label : integer ndarray with shape (n_samples,)
         label[i] is the code or index of the centroid the
         i'th observation is closest to.
 
-    inertia: float
+    inertia : float
         The final value of the inertia criterion (sum of squared distances to
         the closest centroid for all observations in the training set).
 
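
The edits above are purely cosmetic: numpydoc style puts a space before the colon in parameter and return entries, and the k_means code itself is untouched. For orientation, a short sketch of the documented call and its (centroid, label, inertia) return triple, on made-up data:

    import numpy as np
    from sklearn.cluster import k_means

    # Two illustrative, well-separated blobs.
    rng = np.random.RandomState(0)
    X = np.vstack([rng.randn(20, 2), rng.randn(20, 2) + 10])

    # Returns the triple described in the docstring: cluster centers,
    # one label per sample, and the final inertia.
    centroid, label, inertia = k_means(X, n_clusters=2, random_state=0)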

sklearn/cluster/tests/test_affinity_propagation.py

Lines changed: 11 additions & 4 deletions
@@ -5,7 +5,8 @@
 
 import numpy as np
 
-from sklearn.utils.testing import assert_equal, assert_array_equal
+from sklearn.utils.testing import (assert_equal, assert_array_equal,
+                                   assert_raises)
 from sklearn.cluster.affinity_propagation_ import AffinityPropagation
 from sklearn.cluster.affinity_propagation_ import affinity_propagation
 from sklearn.datasets.samples_generator import make_blobs
@@ -25,7 +26,7 @@ def test_affinity_propagation():
     preference = np.median(S) * 10
     # Compute Affinity Propagation
     cluster_centers_indices, labels = affinity_propagation(S,
-                                                   preference=preference)
+                                                           preference=preference)
 
     n_clusters_ = len(cluster_centers_indices)
 
@@ -34,7 +35,7 @@ def test_affinity_propagation():
 
     af = AffinityPropagation(preference=preference, affinity="precomputed")
     labels_precomputed = af.fit(S).labels_
-    af = AffinityPropagation(preference=preference)
+    af = AffinityPropagation(preference=preference, verbose=True)
     labels = af.fit(X).labels_
 
     assert_array_equal(labels, labels_precomputed)
@@ -47,5 +48,11 @@ def test_affinity_propagation():
 
     # Test also with no copy
     _, labels_no_copy = affinity_propagation(S, preference=preference,
-                                         copy=False)
+                                             copy=False)
     assert_array_equal(labels, labels_no_copy)
+
+    # Test input validation
+    assert_raises(ValueError, affinity_propagation, S[:, :-1])
+    assert_raises(ValueError, affinity_propagation, S, damping=0)
+    af = AffinityPropagation(affinity="unknown")
+    assert_raises(ValueError, af.fit, X)

sklearn/cluster/tests/test_hierarchical.py

Lines changed: 7 additions & 2 deletions
@@ -5,6 +5,7 @@
 # Authors: Vincent Michel, 2010, Gael Varoquaux 2012
 # License: BSD-like
 import warnings
+from tempfile import mkdtemp
 
 import numpy as np
 from scipy import sparse
@@ -78,6 +79,10 @@ def test_ward_clustering():
     connectivity = grid_to_graph(*mask.shape)
     clustering = Ward(n_clusters=10, connectivity=connectivity)
     clustering.fit(X)
+    # test caching
+    clustering = Ward(n_clusters=10, connectivity=connectivity,
+                      memory=mkdtemp())
+    clustering.fit(X)
     labels = clustering.labels_
     assert_true(np.size(np.unique(labels)) == 10)
     # Check that we obtain the same solution with early-stopping of the
@@ -94,7 +99,7 @@ def test_ward_clustering():
     assert_raises(TypeError, clustering.fit, X)
     clustering = Ward(n_clusters=10,
                       connectivity=sparse.lil_matrix(
-                              connectivity.todense()[:10, :10]))
+                          connectivity.todense()[:10, :10]))
     assert_raises(ValueError, clustering.fit, X)
 
 
@@ -166,7 +171,7 @@ def test_connectivity_popagation():
                   (.018, .153), (.018, .153), (.018, .153),
                   (.018, .153), (.018, .153), (.018, .153),
                   (.018, .152), (.018, .149), (.018, .144),
-                 ])
+                  ])
     nn = NearestNeighbors(n_neighbors=10, warn_on_equidistant=False).fit(X)
     connectivity = nn.kneighbors_graph(X)
     ward = Ward(n_clusters=4, connectivity=connectivity)
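
The new caching check simply hands Ward a temporary directory as memory, which, combined with the hierarchical.py change above, routes the tree computation through joblib's on-disk cache. A hedged sketch of the same pattern outside the test suite, using the Ward estimator as it existed at the time and synthetic data shaped like the test's (100 samples on a 10x10 connectivity grid):

    from tempfile import mkdtemp
    import numpy as np
    from sklearn.cluster import Ward
    from sklearn.feature_extraction.image import grid_to_graph

    # 100 samples connected on a 10x10 grid, as in test_ward_clustering.
    rng = np.random.RandomState(0)
    X = rng.randn(100, 50)
    connectivity = grid_to_graph(10, 10)

    # Passing a cache directory as ``memory`` lets repeated fits reuse
    # the cached hierarchical tree instead of recomputing it.
    clustering = Ward(n_clusters=10, connectivity=connectivity,
                      memory=mkdtemp())
    clustering.fit(X)
    labels = clustering.labels_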

0 commit comments
