Skip to content

Commit a1c17af

Browse files
TomDLTNicolasHug
andauthored
FIX sort radius neighbors results when sort_results=True and algorithm="brute" (scikit-learn#18612)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 1ea9ae2 commit a1c17af

File tree

3 files changed

+58
-7
lines changed

3 files changed

+58
-7
lines changed

doc/whats_new/v0.24.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,14 @@ Changelog
542542
will raise a `ValueError` when fitting on data with all constant features.
543543
:pr:`18370` by :user:`Trevor Waite <trewaite>`.
544544

545+
- |Fix| In methods `radius_neighbors` and
546+
`radius_neighbors_graph` of :class:`neighbors.NearestNeighbors`,
547+
:class:`neighbors.RadiusNeighborsClassifier`,
548+
:class:`neighbors.RadiusNeighborsRegressor`, and
549+
:class:`neighbors.RadiusNeighborsTransformer`, using `sort_results=True` now
550+
correctly sorts the results even when fitting with the "brute" algorithm.
551+
:pr:`18612` by `Tom Dupre la Tour`_.
552+
545553
:mod:`sklearn.neural_network`
546554
.............................
547555

sklearn/neighbors/_base.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -925,10 +925,10 @@ def radius_neighbors(self, X=None, radius=None, return_distance=True,
925925
Whether or not to return the distances.
926926
927927
sort_results : bool, default=False
928-
If True, the distances and indices will be sorted before being
929-
returned. If `False`, the results will not be sorted. If
930-
`return_distance=False`, setting `sort_results=True` will
931-
result in an error.
928+
If True, the distances and indices will be sorted by increasing
929+
distances before being returned. If False, the results may not
930+
be sorted. If `return_distance=False`, setting `sort_results=True`
931+
will result in an error.
932932
933933
.. versionadded:: 0.22
934934
@@ -1021,6 +1021,16 @@ class from an array representing our data set and ask who's
10211021
neigh_ind_list = sum(chunked_results, [])
10221022
results = _to_object_array(neigh_ind_list)
10231023

1024+
if sort_results:
1025+
if not return_distance:
1026+
raise ValueError("return_distance must be True "
1027+
"if sort_results is True.")
1028+
for ii in range(len(neigh_dist)):
1029+
order = np.argsort(neigh_dist[ii], kind='mergesort')
1030+
neigh_ind[ii] = neigh_ind[ii][order]
1031+
neigh_dist[ii] = neigh_dist[ii][order]
1032+
results = neigh_dist, neigh_ind
1033+
10241034
elif self._fit_method in ['ball_tree', 'kd_tree']:
10251035
if issparse(X):
10261036
raise ValueError(
@@ -1097,9 +1107,9 @@ def radius_neighbors_graph(self, X=None, radius=None, mode='connectivity',
10971107
edges are Euclidean distance between points.
10981108
10991109
sort_results : bool, default=False
1100-
If True, the distances and indices will be sorted before being
1101-
returned. If False, the results will not be sorted.
1102-
Only used with mode='distance'.
1110+
If True, in each row of the result, the non-zero entries will be
1111+
sorted by increasing distances. If False, the non-zero entries may
1112+
not be sorted. Only used with mode='distance'.
11031113
11041114
.. versionadded:: 0.22
11051115

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,39 @@ def test_radius_neighbors_returns_array_of_objects():
678678
assert_array_equal(neigh_ind, expected_ind)
679679

680680

681+
@pytest.mark.parametrize(["algorithm", "metric"], [("ball_tree", "euclidean"),
682+
("kd_tree", "euclidean"),
683+
("brute", "euclidean"),
684+
("brute", "precomputed")])
685+
def test_radius_neighbors_sort_results(algorithm, metric):
686+
# Test radius_neighbors[_graph] output when sort_result is True
687+
n_samples = 10
688+
rng = np.random.RandomState(42)
689+
X = rng.random_sample((n_samples, 4))
690+
691+
if metric == "precomputed":
692+
X = neighbors.radius_neighbors_graph(X, radius=np.inf, mode="distance")
693+
model = neighbors.NearestNeighbors(algorithm=algorithm, metric=metric)
694+
model.fit(X)
695+
696+
# self.radius_neighbors
697+
distances, indices = model.radius_neighbors(X=X, radius=np.inf,
698+
sort_results=True)
699+
for ii in range(n_samples):
700+
assert_array_equal(distances[ii], np.sort(distances[ii]))
701+
702+
# sort_results=True and return_distance=False
703+
if metric != "precomputed": # no need to raise with precomputed graph
704+
with pytest.raises(ValueError, match="return_distance must be True"):
705+
model.radius_neighbors(X=X, radius=np.inf, sort_results=True,
706+
return_distance=False)
707+
708+
# self.radius_neighbors_graph
709+
graph = model.radius_neighbors_graph(X=X, radius=np.inf, mode="distance",
710+
sort_results=True)
711+
assert _is_sorted_by_data(graph)
712+
713+
681714
def test_RadiusNeighborsClassifier_multioutput():
682715
# Test k-NN classifier on multioutput data
683716
rng = check_random_state(0)

0 commit comments

Comments
 (0)