
Commit dc7a685

jeremiedbb authored and jnothman committed
[MRG] FIX n_iter attribute for KMeans, algorithm=elkan (scikit-learn#11353)
1 parent 9a301b4 commit dc7a685

4 files changed: 17 additions, 1 deletion

doc/whats_new/v0.20.rst

Lines changed: 5 additions & 0 deletions
@@ -545,6 +545,11 @@ Decomposition, manifold learning and clustering
   by :user:`Vighnesh Birodkar <vighneshbirodkar>` and
   :user:`Olivier Grisel <ogrisel>`.
 
+- Fixed a bug in :func:`cluster.k_means_elkan` where the returned `iteration`
+  was 1 less than the correct value. Also added the missing `n_iter_` attribute
+  in the docstring of :class:`cluster.KMeans`. :issue:`11353` by
+  :user:`Jeremie du Boisberranger <jeremiedbb>`.
+
 Metrics
 
 - Fixed a bug in :func:`metrics.precision_recall_fscore_support`

sklearn/cluster/_k_means_elkan.pyx

Lines changed: 1 addition & 1 deletion
@@ -258,4 +258,4 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_,
         update_labels_distances_inplace(X_p, centers_p, center_half_distances,
                                         labels, lower_bounds, upper_bounds,
                                         n_samples, n_features, n_clusters)
-    return centers_, labels_, iteration
+    return centers_, labels_, iteration + 1
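
The one-line change above is the usual fix for returning a zero-based loop counter as an iteration count. The loop body of `k_means_elkan` is not part of this diff, so the following is only a toy sketch under that assumption, not the library's code. `run_loop` and its `converge_at` parameter are hypothetical names introduced for illustration.

```python
def run_loop(max_iter, converge_at=None):
    """Toy stand-in for an iterative solver with a zero-based loop counter.

    Assumes max_iter >= 1. `converge_at` is a hypothetical knob: the
    zero-based pass on which the loop breaks early, imitating a
    convergence check.
    """
    iteration = 0
    for iteration in range(max_iter):
        # One pass of label assignment and center update would go here.
        if converge_at is not None and iteration == converge_at:
            break
    zero_based_index = iteration   # the kind of value returned before the fix
    passes_run = iteration + 1     # the kind of value returned after the fix
    return zero_based_index, passes_run


print(run_loop(max_iter=1))                  # (0, 1)
print(run_loop(max_iter=300, converge_at=4)) # (4, 5)
```

With `max_iter=1` the counter never leaves 0, so reporting it directly would claim zero iterations even though one pass ran. The regression test in this commit checks that same case.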

sklearn/cluster/k_means_.py

Lines changed: 3 additions & 0 deletions
@@ -859,6 +859,9 @@ class KMeans(BaseEstimator, ClusterMixin, TransformerMixin):
     inertia_ : float
         Sum of squared distances of samples to their closest cluster center.
 
+    n_iter_ : int
+        Number of iterations run.
+
     Examples
     --------

sklearn/cluster/tests/test_k_means.py

Lines changed: 8 additions & 0 deletions
@@ -979,3 +979,11 @@ def test_check_sample_weight():
     assert_equal(_num_samples(X), _num_samples(checked_sample_weight))
     assert_almost_equal(checked_sample_weight.sum(), _num_samples(X))
     assert_equal(X.dtype, checked_sample_weight.dtype)
+
+
+def test_iter_attribute():
+    # Regression test on bad n_iter_ value. Previous bug n_iter_ was one off
+    # it's right value (#11340).
+    estimator = KMeans(algorithm="elkan", max_iter=1)
+    estimator.fit(np.random.rand(10, 10))
+    assert estimator.n_iter_ == 1
