Skip to content

Commit fc474d6

Browse files
authored
ENH Release GIL in DistanceMetric when validating data (scikit-learn#17038)
1 parent a0a23d2 commit fc474d6

File tree

3 files changed

+46
-23
lines changed

3 files changed

+46
-23
lines changed

doc/whats_new/v0.24.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,20 @@ Changelog
8383
- |Enhancement| :func:`tree.plot_tree` now uses colors from the matplotlib
8484
configuration settings. :pr:`17187` by `Andreas Müller`_.
8585

86+
:mod:`sklearn.neighbors`
87+
.............................
88+
89+
- |Efficiency| Speed up ``seuclidean``, ``wminkowski``, ``mahalanobis`` and
90+
``haversine`` metrics in :class:`neighbors.DistanceMetric` by avoiding
91+
unexpected GIL acquiring in Cython when setting ``n_jobs>1`` in
92+
:class:`neighbors.KNeighborsClassifier`,
93+
:class:`neighbors.KNeighborsRegressor`,
94+
:class:`neighbors.RadiusNeighborsClassifier`,
95+
:class:`neighbors.RadiusNeighborsRegressor`,
96+
:func:`metrics.pairwise_distances`
97+
and by validating data out of loops.
98+
:pr:`17038` by :user:`Wenbo Zhao <webber26232>`.
99+
86100
Code and Documentation Contributors
87101
-----------------------------------
88102

sklearn/neighbors/_binary_tree.pxi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,6 +1067,7 @@ cdef class BinaryTree:
10671067
raise ValueError('metric {metric} is not valid for '
10681068
'{BinaryTree}'.format(metric=metric,
10691069
**DOC_DICT))
1070+
self.dist_metric._validate_data(data)
10701071

10711072
# determine number of levels in the tree, and from this
10721073
# the number of nodes in the tree. This results in leaf nodes

sklearn/neighbors/_dist_metrics.pyx

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,14 @@ cdef class DistanceMetric:
292292
if self.__class__ is DistanceMetric:
293293
raise NotImplementedError("DistanceMetric is an abstract class")
294294

295+
def _validate_data(self, X):
296+
"""Validate the input data.
297+
298+
This should be overridden in a base class if a specific input format
299+
is required.
300+
"""
301+
return
302+
295303
cdef DTYPE_t dist(self, DTYPE_t* x1, DTYPE_t* x2,
296304
ITYPE_t size) nogil except -1:
297305
"""Compute the distance between vectors x1 and x2
@@ -386,13 +394,15 @@ cdef class DistanceMetric:
386394
cdef np.ndarray[DTYPE_t, ndim=2, mode='c'] Darr
387395

388396
Xarr = np.asarray(X, dtype=DTYPE, order='C')
397+
self._validate_data(Xarr)
389398
if Y is None:
390399
Darr = np.zeros((Xarr.shape[0], Xarr.shape[0]),
391400
dtype=DTYPE, order='C')
392401
self.pdist(get_memview_DTYPE_2D(Xarr),
393402
get_memview_DTYPE_2D(Darr))
394403
else:
395404
Yarr = np.asarray(Y, dtype=DTYPE, order='C')
405+
self._validate_data(Yarr)
396406
Darr = np.zeros((Xarr.shape[0], Yarr.shape[0]),
397407
dtype=DTYPE, order='C')
398408
self.cdist(get_memview_DTYPE_2D(Xarr),
@@ -449,11 +459,12 @@ cdef class SEuclideanDistance(DistanceMetric):
449459
self.size = self.vec.shape[0]
450460
self.p = 2
451461

462+
def _validate_data(self, X):
463+
if X.shape[1] != self.size:
464+
raise ValueError('SEuclidean dist: size of V does not match')
465+
452466
cdef inline DTYPE_t rdist(self, DTYPE_t* x1, DTYPE_t* x2,
453467
ITYPE_t size) nogil except -1:
454-
if size != self.size:
455-
with gil:
456-
raise ValueError('SEuclidean dist: size of V does not match')
457468
cdef DTYPE_t tmp, d=0
458469
cdef np.intp_t j
459470
for j in range(size):
@@ -597,12 +608,13 @@ cdef class WMinkowskiDistance(DistanceMetric):
597608
self.vec_ptr = get_vec_ptr(self.vec)
598609
self.size = self.vec.shape[0]
599610

611+
def _validate_data(self, X):
612+
if X.shape[1] != self.size:
613+
raise ValueError('WMinkowskiDistance dist: '
614+
'size of w does not match')
615+
600616
cdef inline DTYPE_t rdist(self, DTYPE_t* x1, DTYPE_t* x2,
601617
ITYPE_t size) nogil except -1:
602-
if size != self.size:
603-
with gil:
604-
raise ValueError('WMinkowskiDistance dist: '
605-
'size of w does not match')
606618
cdef DTYPE_t d=0
607619
cdef np.intp_t j
608620
for j in range(size):
@@ -662,12 +674,12 @@ cdef class MahalanobisDistance(DistanceMetric):
662674
self.vec = np.zeros(self.size, dtype=DTYPE)
663675
self.vec_ptr = get_vec_ptr(self.vec)
664676

677+
def _validate_data(self, X):
678+
if X.shape[1] != self.size:
679+
raise ValueError('Mahalanobis dist: size of V does not match')
680+
665681
cdef inline DTYPE_t rdist(self, DTYPE_t* x1, DTYPE_t* x2,
666682
ITYPE_t size) nogil except -1:
667-
if size != self.size:
668-
with gil:
669-
raise ValueError('Mahalanobis dist: size of V does not match')
670-
671683
cdef DTYPE_t tmp, d = 0
672684
cdef np.intp_t i, j
673685

@@ -986,25 +998,21 @@ cdef class HaversineDistance(DistanceMetric):
986998
D(x, y) = 2\\arcsin[\\sqrt{\\sin^2((x1 - y1) / 2)
987999
+ \\cos(x1)\\cos(y1)\\sin^2((x2 - y2) / 2)}]
9881000
"""
1001+
1002+
def _validate_data(self, X):
1003+
if X.shape[1] != 2:
1004+
raise ValueError("Haversine distance only valid "
1005+
"in 2 dimensions")
1006+
9891007
cdef inline DTYPE_t rdist(self, DTYPE_t* x1, DTYPE_t* x2,
9901008
ITYPE_t size) nogil except -1:
991-
if size != 2:
992-
with gil:
993-
raise ValueError("Haversine distance only valid "
994-
"in 2 dimensions")
9951009
cdef DTYPE_t sin_0 = sin(0.5 * (x1[0] - x2[0]))
9961010
cdef DTYPE_t sin_1 = sin(0.5 * (x1[1] - x2[1]))
9971011
return (sin_0 * sin_0 + cos(x1[0]) * cos(x2[0]) * sin_1 * sin_1)
9981012

9991013
cdef inline DTYPE_t dist(self, DTYPE_t* x1, DTYPE_t* x2,
1000-
ITYPE_t size) nogil except -1:
1001-
if size != 2:
1002-
with gil:
1003-
raise ValueError("Haversine distance only valid in 2 dimensions")
1004-
cdef DTYPE_t sin_0 = sin(0.5 * (x1[0] - x2[0]))
1005-
cdef DTYPE_t sin_1 = sin(0.5 * (x1[1] - x2[1]))
1006-
return 2 * asin(sqrt(sin_0 * sin_0
1007-
+ cos(x1[0]) * cos(x2[0]) * sin_1 * sin_1))
1014+
ITYPE_t size) nogil except -1:
1015+
return 2 * asin(sqrt(self.rdist(x1, x2, size)))
10081016

10091017
cdef inline DTYPE_t _rdist_to_dist(self, DTYPE_t rdist) nogil except -1:
10101018
return 2 * asin(sqrt(rdist))

0 commit comments

Comments
 (0)