Skip to content

Commit b34edd3

Browse files
FIX Allow for KMean's attributes to be readonly (scikit-learn#24258)
Co-authored-by: Julien Jerphanion <[email protected]>
1 parent dc45d03 commit b34edd3

File tree

4 files changed

+41
-19
lines changed

4 files changed

+41
-19
lines changed

doc/whats_new/v1.2.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ Changelog
138138
:class:`cluster.AgglomerativeClustering` and will be renamed to `metric` in v1.4.
139139
:pr:`23470` by :user:`Meekail Zain <micky774>`.
140140

141+
- |Fix| :class:`cluster.KMeans` now supports readonly attributes when predicting.
142+
:pr:`24258` by `Thomas Fan`_
143+
141144
:mod:`sklearn.datasets`
142145
.......................
143146

sklearn/cluster/_k_means_lloyd.pyx

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,13 @@ def lloyd_iter_chunked_dense(
5454
5555
centers_new : ndarray of shape (n_clusters, n_features), dtype=floating
5656
Centers after previous iteration, placeholder for the new centers
57-
computed during this iteration.
58-
59-
centers_squared_norms : ndarray of shape (n_clusters,), dtype=floating
60-
Squared L2 norm of the centers.
57+
computed during this iteration. `centers_new` can be `None` if
58+
`update_centers` is False.
6159
6260
weight_in_clusters : ndarray of shape (n_clusters,), dtype=floating
6361
Placeholder for the sums of the weights of every observation assigned
64-
to each center.
62+
to each center. `weight_in_clusters` can be `None` if `update_centers`
63+
is False.
6564
6665
labels : ndarray of shape (n_samples,), dtype=int
6766
labels assignment.
@@ -82,7 +81,7 @@ def lloyd_iter_chunked_dense(
8281
cdef:
8382
int n_samples = X.shape[0]
8483
int n_features = X.shape[1]
85-
int n_clusters = centers_new.shape[0]
84+
int n_clusters = centers_old.shape[0]
8685

8786
# hard-coded number of samples per chunk. Appeared to be close to
8887
# optimal in all situations.
@@ -253,14 +252,13 @@ def lloyd_iter_chunked_sparse(
253252
254253
centers_new : ndarray of shape (n_clusters, n_features), dtype=floating
255254
Centers after previous iteration, placeholder for the new centers
256-
computed during this iteration.
257-
258-
centers_squared_norms : ndarray of shape (n_clusters,), dtype=floating
259-
Squared L2 norm of the centers.
255+
computed during this iteration. `centers_new` can be `None` if
256+
`update_centers` is False.
260257
261258
weight_in_clusters : ndarray of shape (n_clusters,), dtype=floating
262259
Placeholder for the sums of the weights of every observation assigned
263-
to each center.
260+
to each center. `weight_in_clusters` can be `None` if `update_centers`
261+
is False.
264262
265263
labels : ndarray of shape (n_samples,), dtype=int
266264
labels assignment.
@@ -282,7 +280,7 @@ def lloyd_iter_chunked_sparse(
282280
cdef:
283281
int n_samples = X.shape[0]
284282
int n_features = X.shape[1]
285-
int n_clusters = centers_new.shape[0]
283+
int n_clusters = centers_old.shape[0]
286284

287285
# Choose same as for dense. Does not have the same impact since with
288286
# sparse data the pairwise distances matrix is not precomputed.

sklearn/cluster/_kmeans.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -784,8 +784,7 @@ def _labels_inertia(
784784
n_clusters = centers.shape[0]
785785

786786
labels = np.full(n_samples, -1, dtype=np.int32)
787-
weight_in_clusters = np.zeros(n_clusters, dtype=centers.dtype)
788-
center_shift = np.zeros_like(weight_in_clusters)
787+
center_shift = np.zeros(n_clusters, dtype=centers.dtype)
789788

790789
if sp.issparse(X):
791790
_labels = lloyd_iter_chunked_sparse
@@ -795,16 +794,17 @@ def _labels_inertia(
795794
_inertia = _inertia_dense
796795
X = ReadonlyArrayWrapper(X)
797796

797+
centers = ReadonlyArrayWrapper(centers)
798798
_labels(
799799
X,
800800
sample_weight,
801801
x_squared_norms,
802802
centers,
803-
centers,
804-
weight_in_clusters,
805-
labels,
806-
center_shift,
807-
n_threads,
803+
centers_new=None,
804+
weight_in_clusters=None,
805+
labels=labels,
806+
center_shift=center_shift,
807+
n_threads=n_threads,
808808
update_centers=False,
809809
)
810810

sklearn/cluster/tests/test_k_means.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sklearn.cluster._k_means_common import _inertia_dense
3030
from sklearn.cluster._k_means_common import _inertia_sparse
3131
from sklearn.cluster._k_means_common import _is_same_clustering
32+
from sklearn.utils._testing import create_memmap_backed_data
3233
from sklearn.datasets import make_blobs
3334
from io import StringIO
3435

@@ -1213,3 +1214,23 @@ def test_feature_names_out(Klass, method):
12131214

12141215
names_out = kmeans.get_feature_names_out()
12151216
assert_array_equal([f"{class_name}{i}" for i in range(n_clusters)], names_out)
1217+
1218+
1219+
@pytest.mark.parametrize("is_sparse", [True, False])
1220+
def test_predict_does_not_change_cluster_centers(is_sparse):
1221+
"""Check that predict does not change cluster centers.
1222+
1223+
Non-regression test for gh-24253.
1224+
"""
1225+
X, _ = make_blobs(n_samples=200, n_features=10, centers=10, random_state=0)
1226+
if is_sparse:
1227+
X = sp.csr_matrix(X)
1228+
1229+
kmeans = KMeans()
1230+
y_pred1 = kmeans.fit_predict(X)
1231+
# Make cluster_centers readonly
1232+
kmeans.cluster_centers_ = create_memmap_backed_data(kmeans.cluster_centers_)
1233+
kmeans.labels_ = create_memmap_backed_data(kmeans.labels_)
1234+
1235+
y_pred2 = kmeans.predict(X)
1236+
assert_array_equal(y_pred1, y_pred2)

0 commit comments

Comments
 (0)