Commit 061803c

Erotemic authored and jnothman committed
[MRG+2] Fixed n**2 memory blowup in _labels_inertia_precompute_dense (scikit-learn#7721)
1 parent 94c2094 commit 061803c

2 files changed, with 15 additions and 11 deletions.

doc/whats_new.rst

Lines changed: 6 additions & 0 deletions
@@ -22,6 +22,12 @@ New features
 Enhancements
 ............
 
+- :class:`cluster.MiniBatchKMeans` and :class:`cluster.KMeans`
+  now uses significantly less memory when assigning data points to their
+  nearest cluster center.
+  (`#7721 <https://github.com/scikit-learn/scikit-learn/pull/7721>`_)
+  By `Jon Crall`_.
+
 - Added ``classes_`` attribute to :class:`model_selection.GridSearchCV`
   that matches the ``classes_`` attribute of ``best_estimator_``. (`#7661
   <https://github.com/scikit-learn/scikit-learn/pull/7661>`_) by `Alyssa

sklearn/cluster/k_means_.py

Lines changed: 9 additions & 11 deletions
@@ -18,6 +18,7 @@
 
 from ..base import BaseEstimator, ClusterMixin, TransformerMixin
 from ..metrics.pairwise import euclidean_distances
+from ..metrics.pairwise import pairwise_distances_argmin_min
 from ..utils.extmath import row_norms, squared_norm, stable_cumsum
 from ..utils.sparsefuncs_fast import assign_rows_csr
 from ..utils.sparsefuncs import mean_variance_axis
@@ -552,17 +553,14 @@ def _labels_inertia_precompute_dense(X, x_squared_norms, centers, distances):
 
     """
     n_samples = X.shape[0]
-    k = centers.shape[0]
-    all_distances = euclidean_distances(centers, X, x_squared_norms,
-                                        squared=True)
-    labels = np.empty(n_samples, dtype=np.int32)
-    labels.fill(-1)
-    mindist = np.empty(n_samples)
-    mindist.fill(np.infty)
-    for center_id in range(k):
-        dist = all_distances[center_id]
-        labels[dist < mindist] = center_id
-        mindist = np.minimum(dist, mindist)
+
+    # Breakup nearest neighbor distance computation into batches to prevent
+    # memory blowup in the case of a large number of samples and clusters.
+    # TODO: Once PR #7383 is merged use check_inputs=False in metric_kwargs.
+    labels, mindist = pairwise_distances_argmin_min(
+        X=X, Y=centers, metric='euclidean', metric_kwargs={'squared': True})
+    # cython k-means code assumes int32 inputs
+    labels = labels.astype(np.int32)
     if n_samples == distances.shape[0]:
         # distances will be changed in-place
         distances[:] = mindist
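
The change above swaps a full distance matrix for a batched nearest-center search. The following NumPy-only sketch illustrates that idea under stated assumptions: the function names `labels_mindist_full` and `labels_mindist_batched` and the `batch_size` parameter are hypothetical and exist only for this illustration. It is not the scikit-learn implementation of `pairwise_distances_argmin_min`, only an approximation of the memory trade-off the commit comment describes.

```python
import numpy as np


def labels_mindist_full(X, centers):
    """Mimic the removed code: build the whole (k, n_samples) matrix first.

    Hypothetical helper for illustration; peak extra memory is O(k * n_samples).
    """
    x_sq = (X ** 2).sum(axis=1)
    c_sq = (centers ** 2).sum(axis=1)
    # Squared Euclidean distances via ||c||^2 - 2 c.x + ||x||^2
    all_distances = c_sq[:, None] - 2.0 * centers @ X.T + x_sq[None, :]
    np.maximum(all_distances, 0, out=all_distances)  # clip rounding noise

    n_samples = X.shape[0]
    labels = np.full(n_samples, -1, dtype=np.int32)
    mindist = np.full(n_samples, np.inf)
    for center_id in range(centers.shape[0]):
        dist = all_distances[center_id]
        labels[dist < mindist] = center_id
        mindist = np.minimum(dist, mindist)
    return labels, mindist


def labels_mindist_batched(X, centers, batch_size=256):
    """Process samples in chunks so only a (batch_size, k) block is alive.

    Hypothetical helper for illustration; peak extra memory is O(batch_size * k).
    """
    n_samples = X.shape[0]
    labels = np.empty(n_samples, dtype=np.int32)  # int32, as the Cython code expects
    mindist = np.empty(n_samples)
    c_sq = (centers ** 2).sum(axis=1)
    for start in range(0, n_samples, batch_size):
        stop = min(start + batch_size, n_samples)
        chunk = X[start:stop]
        d = (chunk ** 2).sum(axis=1)[:, None] - 2.0 * chunk @ centers.T + c_sq[None, :]
        np.maximum(d, 0, out=d)  # clip rounding noise
        labels[start:stop] = d.argmin(axis=1)
        mindist[start:stop] = d.min(axis=1)
    return labels, mindist
```

Both helpers return the index of the nearest center and the squared distance to it for every sample. The batched variant trades a Python-level loop over chunks for a temporary array whose size no longer grows with the number of samples.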
