Skip to content

Commit 376c570

Browse files
jakevdpogrisel
authored andcommitted
BUG: use correct algorithm for callable metric
1 parent 1a8b797 commit 376c570

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

sklearn/neighbors/base.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ def _init_params(self, n_neighbors=None, radius=None,
115115
alg_check = algorithm
116116

117117
if metric not in VALID_METRICS[alg_check]:
118-
# callable metric is valid for brute force, kd_tree, and ball_tree
119-
if callable(metric):
118+
# callable metric is valid for brute force and ball_tree
119+
if callable(metric) and algorithm != 'kd_tree':
120120
pass
121121
else:
122122
raise ValueError("metric '%s' not valid for algorithm '%s'"
@@ -199,8 +199,7 @@ def _fit(self, X):
199199
# and KDTree is generally faster when available
200200
if (self.n_neighbors is None
201201
or self.n_neighbors < self._fit_X.shape[0] // 2):
202-
if (callable(self.effective_metric_)
203-
or self.effective_metric_ in VALID_METRICS['kd_tree']):
202+
if self.effective_metric_ in VALID_METRICS['kd_tree']:
204203
self._fit_method = 'kd_tree'
205204
else:
206205
self._fit_method = 'ball_tree'

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,22 @@ def test_neighbors_metrics(n_samples=20, n_features=3,
639639
assert_array_almost_equal(results[0][1], results[1][1])
640640

641641

642+
def test_callable_metric():
643+
metric = lambda x1, x2: np.sqrt(np.sum(x1 ** 2 + x2 ** 2))
644+
645+
X = np.random.random((20, 2))
646+
nbrs1 = neighbors.NearestNeighbors(3, algorithm='auto', metric=metric)
647+
nbrs2 = neighbors.NearestNeighbors(3, algorithm='brute', metric=metric)
648+
649+
nbrs1.fit(X)
650+
nbrs2.fit(X)
651+
652+
dist1, ind1 = nbrs1.kneighbors(X)
653+
dist2, ind2 = nbrs2.kneighbors(X)
654+
655+
assert_array_almost_equal(dist1, dist2)
656+
657+
642658
if __name__ == '__main__':
643659
import nose
644660
nose.runmodule()

0 commit comments

Comments
 (0)