Skip to content

Commit 7bc92af

Browse files
committed
BUG: fix ward tests
1 parent 7a86168 commit 7bc92af

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

sklearn/cluster/hierarchical.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,13 @@ def ward_tree(X, connectivity=None, n_components=None, copy=True,
7272
X = np.reshape(X, (-1, 1))
7373

7474
if connectivity is None:
75+
if n_clusters is not None:
76+
raise ValueError('Early stopping is implemented only for '
77+
'structured Ward clustering (i.e. with '
78+
'explicit connectivity.')
7579
out = hierarchy.ward(X)
7680
children_ = out[:, :2].astype(np.int)
77-
return children_, 1, n_samples
81+
return children_, 1, n_samples, None
7882

7983
# Compute the number of nodes
8084
if n_components is None:
@@ -331,6 +335,8 @@ def fit(self, X):
331335

332336
n_samples = len(X)
333337
compute_full_tree = self.compute_full_tree
338+
if self.connectivity is None:
339+
compute_full_tree = None
334340
if compute_full_tree == 'auto':
335341
# Early stopping is likely to give a speed up only for
336342
# a large number of clusters. The actual threshold

sklearn/cluster/tests/test_hierarchical.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def test_structured_ward_tree():
2121
mask = np.ones([10, 10], dtype=np.bool)
2222
X = rnd.randn(50, 100)
2323
connectivity = grid_to_graph(*mask.shape)
24-
children, n_components, n_leaves = ward_tree(X.T, connectivity)
24+
children, n_components, n_leaves, parent = ward_tree(X.T, connectivity)
2525
n_nodes = 2 * X.shape[1] - 1
2626
assert_true(len(children) + n_leaves == n_nodes)
2727

@@ -32,7 +32,7 @@ def test_unstructured_ward_tree():
3232
"""
3333
rnd = np.random.RandomState(0)
3434
X = rnd.randn(50, 100)
35-
children, n_nodes, n_leaves = ward_tree(X.T)
35+
children, n_nodes, n_leaves, parent = ward_tree(X.T)
3636
n_nodes = 2 * X.shape[1] - 1
3737
assert_true(len(children) + n_leaves == n_nodes)
3838

@@ -45,7 +45,7 @@ def test_height_ward_tree():
4545
mask = np.ones([10, 10], dtype=np.bool)
4646
X = rnd.randn(50, 100)
4747
connectivity = grid_to_graph(*mask.shape)
48-
children, n_nodes, n_leaves = ward_tree(X.T, connectivity)
48+
children, n_nodes, n_leaves, parent = ward_tree(X.T, connectivity)
4949
n_nodes = 2 * X.shape[1] - 1
5050
assert_true(len(children) + n_leaves == n_nodes)
5151

@@ -109,7 +109,7 @@ def test_scikit_vs_scipy():
109109
out = hierarchy.ward(X)
110110

111111
children_ = out[:, :2].astype(np.int)
112-
children, _, n_leaves = ward_tree(X, connectivity)
112+
children, _, n_leaves, _ = ward_tree(X, connectivity)
113113

114114
cut = _hc_cut(k, children, n_leaves)
115115
cut_ = _hc_cut(k, children_, n_leaves)

0 commit comments

Comments
 (0)