Skip to content

Commit df1e711

Browse files
committed
DOC: ward docstring and testing
1 parent 7bc92af commit df1e711

File tree

3 files changed

+43
-7
lines changed

3 files changed

+43
-7
lines changed

doc/tutorial/statistical_inference/unsupervised_learning.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ transposed data.
214214
>>> agglo = cluster.WardAgglomeration(connectivity=connectivity,
215215
... n_clusters=32)
216216
>>> agglo.fit(X) # doctest: +ELLIPSIS
217-
WardAgglomeration(connectivity=...
217+
WardAgglomeration(compute_full_tree='auto', ...
218218
>>> X_reduced = agglo.transform(X)
219219

220220
>>> X_approx = agglo.inverse_transform(X_reduced)

sklearn/cluster/hierarchical.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,29 @@ def ward_tree(X, connectivity=None, n_components=None, copy=True,
5454
Make a copy of connectivity or work inplace. If connectivity
5555
is not of LIL type there will be a copy in any case.
5656
57+
n_clusters : int (optional)
58+
Stop early the construction of the tree at n_clusters. This is
59+
useful to decrease computation time if the number of clusters is
60+
not small compared to the number of samples. In this case, the
61+
complete tree is not computed, thus the 'children' output is of
62+
limited use, and the 'parents' output should rather be used.
63+
This option is valid only when specifying a connectivity matrix.
64+
5765
Returns
5866
-------
59-
children : list of pairs. Lenght of n_nodes
60-
list of the children of each nodes.
61-
Leaves of the tree have empty list of children.
67+
children : 2D array, shape (n_nodes, 2)
68+
list of the children of each nodes.
69+
Leaves of the tree have empty list of children.
6270
6371
n_components : sparse matrix.
6472
The number of connected components in the graph.
6573
6674
n_leaves : int
6775
The number of leaves in the tree
76+
77+
parents : 1D array, shape (n_nodes, ) or None
78+
The parent of each node. Only returned when a connectivity matrix
79+
is specified, elsewhere 'None' is returned.
6880
"""
6981
X = np.asarray(X)
7082
n_samples, n_features = X.shape
@@ -73,9 +85,9 @@ def ward_tree(X, connectivity=None, n_components=None, copy=True,
7385

7486
if connectivity is None:
7587
if n_clusters is not None:
76-
raise ValueError('Early stopping is implemented only for '
88+
warnings.warn('Early stopping is implemented only for '
7789
'structured Ward clustering (i.e. with '
78-
'explicit connectivity.')
90+
'explicit connectivity.', stacklevel=2)
7991
out = hierarchy.ward(X)
8092
children_ = out[:, :2].astype(np.int)
8193
return children_, 1, n_samples, None
@@ -284,6 +296,15 @@ class Ward(BaseEstimator):
284296
The number of connected components in the graph defined by the \
285297
connectivity matrix. If not set, it is estimated.
286298
299+
compute_full_tree: bool or 'auto' (optional)
300+
Stop early the construction of the tree at n_clusters. This is
301+
useful to decrease computation time if the number of clusters is
302+
not small compared to the number of samples. This option is
303+
useful only when specifying a connectivity matrix. Note also that
304+
when varying the number of cluster and using caching, it may
305+
be advantageous to compute the full tree.
306+
307+
287308
Attributes
288309
----------
289310
`children_` : array-like, shape = [n_nodes, 2]
@@ -393,6 +414,15 @@ class WardAgglomeration(AgglomerationTransform, Ward):
393414
The number of connected components in the graph defined by the
394415
connectivity matrix. If not set, it is estimated.
395416
417+
compute_full_tree: bool or 'auto' (optional)
418+
Stop early the construction of the tree at n_clusters. This is
419+
useful to decrease computation time if the number of clusters is
420+
not small compared to the number of samples. This option is
421+
useful only when specifying a connectivity matrix. Note also that
422+
when varying the number of cluster and using caching, it may
423+
be advantageous to compute the full tree.
424+
425+
396426
Attributes
397427
----------
398428
`children_` : array-like, shape = [n_nodes, 2]

sklearn/cluster/tests/test_hierarchical.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,13 @@ def test_ward_clustering():
6060
connectivity = grid_to_graph(*mask.shape)
6161
clustering = Ward(n_clusters=10, connectivity=connectivity)
6262
clustering.fit(X)
63-
assert_true(np.size(np.unique(clustering.labels_)) == 10)
63+
labels = clustering.labels_
64+
assert_true(np.size(np.unique(labels)) == 10)
65+
# Check that we obtain the same solution with early-stopping of the
66+
# tree building
67+
clustering.compute_full_tree = False
68+
clustering.fit(X)
69+
np.testing.assert_array_equal(clustering.labels_, labels)
6470

6571

6672
def test_ward_agglomeration():

0 commit comments

Comments
 (0)