|
18 | 18 |
|
19 | 19 | from ..base import BaseEstimator, ClusterMixin, TransformerMixin |
20 | 20 | from ..metrics.pairwise import euclidean_distances |
| 21 | +from ..metrics.pairwise import pairwise_distances_argmin_min |
21 | 22 | from ..utils.extmath import row_norms, squared_norm, stable_cumsum |
22 | 23 | from ..utils.sparsefuncs_fast import assign_rows_csr |
23 | 24 | from ..utils.sparsefuncs import mean_variance_axis |
@@ -552,17 +553,14 @@ def _labels_inertia_precompute_dense(X, x_squared_norms, centers, distances): |
552 | 553 |
|
553 | 554 | """ |
554 | 555 | n_samples = X.shape[0] |
555 | | - k = centers.shape[0] |
556 | | - all_distances = euclidean_distances(centers, X, x_squared_norms, |
557 | | - squared=True) |
558 | | - labels = np.empty(n_samples, dtype=np.int32) |
559 | | - labels.fill(-1) |
560 | | - mindist = np.empty(n_samples) |
561 | | - mindist.fill(np.infty) |
562 | | - for center_id in range(k): |
563 | | - dist = all_distances[center_id] |
564 | | - labels[dist < mindist] = center_id |
565 | | - mindist = np.minimum(dist, mindist) |
| 556 | + |
| 557 | + # Break up nearest neighbor distance computation into batches to prevent |
| 558 | + # memory blowup in the case of a large number of samples and clusters. |
| 559 | + # TODO: Once PR #7383 is merged use check_inputs=False in metric_kwargs. |
| 560 | + labels, mindist = pairwise_distances_argmin_min( |
| 561 | + X=X, Y=centers, metric='euclidean', metric_kwargs={'squared': True}) |
| 562 | + # cython k-means code assumes int32 inputs |
| 563 | + labels = labels.astype(np.int32) |
566 | 564 | if n_samples == distances.shape[0]: |
567 | 565 | # distances will be changed in-place |
568 | 566 | distances[:] = mindist |
|
0 commit comments