@@ -292,6 +292,14 @@ cdef class DistanceMetric:
292292 if self .__class__ is DistanceMetric :
293293 raise NotImplementedError ("DistanceMetric is an abstract class" )
294294
295+ def _validate_data (self , X ):
296+ """Validate the input data.
297+
298+ This should be overridden in a subclass if a specific input format
299+ is required; overrides in this file raise ValueError on a mismatch.
300+ """
301+ return
302+
295303 cdef DTYPE_t dist (self , DTYPE_t * x1 , DTYPE_t * x2 ,
296304 ITYPE_t size ) nogil except - 1 :
297305 """Compute the distance between vectors x1 and x2
@@ -386,13 +394,15 @@ cdef class DistanceMetric:
386394 cdef np .ndarray [DTYPE_t , ndim = 2 , mode = 'c' ] Darr
387395
388396 Xarr = np .asarray (X , dtype = DTYPE , order = 'C' )
397+ self ._validate_data (Xarr )
389398 if Y is None :
390399 Darr = np .zeros ((Xarr .shape [0 ], Xarr .shape [0 ]),
391400 dtype = DTYPE , order = 'C' )
392401 self .pdist (get_memview_DTYPE_2D (Xarr ),
393402 get_memview_DTYPE_2D (Darr ))
394403 else :
395404 Yarr = np .asarray (Y , dtype = DTYPE , order = 'C' )
405+ self ._validate_data (Yarr )
396406 Darr = np .zeros ((Xarr .shape [0 ], Yarr .shape [0 ]),
397407 dtype = DTYPE , order = 'C' )
398408 self .cdist (get_memview_DTYPE_2D (Xarr ),
@@ -449,11 +459,12 @@ cdef class SEuclideanDistance(DistanceMetric):
449459 self .size = self .vec .shape [0 ]
450460 self .p = 2
451461
462+ def _validate_data (self , X ):  # reject X whose column count differs from self.size
463+ if X .shape [1 ] != self .size :  # self.size is self.vec.shape[0]; assumes X is 2-D -- TODO confirm
464+ raise ValueError ('SEuclidean dist: size of V does not match' )
465+
452466 cdef inline DTYPE_t rdist (self , DTYPE_t * x1 , DTYPE_t * x2 ,
453467 ITYPE_t size ) nogil except - 1 :
454- if size != self .size :
455- with gil :
456- raise ValueError ('SEuclidean dist: size of V does not match' )
457468 cdef DTYPE_t tmp , d = 0
458469 cdef np .intp_t j
459470 for j in range (size ):
@@ -597,12 +608,13 @@ cdef class WMinkowskiDistance(DistanceMetric):
597608 self .vec_ptr = get_vec_ptr (self .vec )
598609 self .size = self .vec .shape [0 ]
599610
611+ def _validate_data (self , X ):  # reject X whose column count differs from self.size
612+ if X .shape [1 ] != self .size :  # self.size is self.vec.shape[0]; assumes X is 2-D -- TODO confirm
613+ raise ValueError ('WMinkowskiDistance dist: '
614+ 'size of w does not match' )
615+
600616 cdef inline DTYPE_t rdist (self , DTYPE_t * x1 , DTYPE_t * x2 ,
601617 ITYPE_t size ) nogil except - 1 :
602- if size != self .size :
603- with gil :
604- raise ValueError ('WMinkowskiDistance dist: '
605- 'size of w does not match' )
606618 cdef DTYPE_t d = 0
607619 cdef np .intp_t j
608620 for j in range (size ):
@@ -662,12 +674,12 @@ cdef class MahalanobisDistance(DistanceMetric):
662674 self .vec = np .zeros (self .size , dtype = DTYPE )
663675 self .vec_ptr = get_vec_ptr (self .vec )
664676
677+ def _validate_data (self , X ):  # reject X whose column count differs from self.size
678+ if X .shape [1 ] != self .size :  # self.vec is allocated with length self.size; assumes X is 2-D -- TODO confirm
679+ raise ValueError ('Mahalanobis dist: size of V does not match' )
680+
665681 cdef inline DTYPE_t rdist (self , DTYPE_t * x1 , DTYPE_t * x2 ,
666682 ITYPE_t size ) nogil except - 1 :
667- if size != self .size :
668- with gil :
669- raise ValueError ('Mahalanobis dist: size of V does not match' )
670-
671683 cdef DTYPE_t tmp , d = 0
672684 cdef np .intp_t i , j
673685
@@ -986,25 +998,21 @@ cdef class HaversineDistance(DistanceMetric):
986998 D(x, y) = 2\\ arcsin[\\ sqrt{\\ sin^2((x1 - y1) / 2)
987999 + \\ cos(x1)\\ cos(y1)\\ sin^2((x2 - y2) / 2)}]
9881000 """
1001+
1002+ def _validate_data (self , X ):  # reject X that does not have exactly 2 columns
1003+ if X .shape [1 ] != 2 :  # rdist reads only x[0] and x[1] of each point
1004+ raise ValueError ("Haversine distance only valid "
1005+ "in 2 dimensions" )
1006+
9891007 cdef inline DTYPE_t rdist (self , DTYPE_t * x1 , DTYPE_t * x2 ,
9901008 ITYPE_t size ) nogil except - 1 :
991- if size != 2 :
992- with gil :
993- raise ValueError ("Haversine distance only valid "
994- "in 2 dimensions" )
9951009 cdef DTYPE_t sin_0 = sin (0.5 * (x1 [0 ] - x2 [0 ]))
9961010 cdef DTYPE_t sin_1 = sin (0.5 * (x1 [1 ] - x2 [1 ]))
9971011 return (sin_0 * sin_0 + cos (x1 [0 ]) * cos (x2 [0 ]) * sin_1 * sin_1 )  # the quantity under the sqrt in the class docstring formula; size is not read here
9981012
9991013 cdef inline DTYPE_t dist (self , DTYPE_t * x1 , DTYPE_t * x2 ,
1000- ITYPE_t size ) nogil except - 1 :
1001- if size != 2 :
1002- with gil :
1003- raise ValueError ("Haversine distance only valid in 2 dimensions" )
1004- cdef DTYPE_t sin_0 = sin (0.5 * (x1 [0 ] - x2 [0 ]))
1005- cdef DTYPE_t sin_1 = sin (0.5 * (x1 [1 ] - x2 [1 ]))
1006- return 2 * asin (sqrt (sin_0 * sin_0
1007- + cos (x1 [0 ]) * cos (x2 [0 ]) * sin_1 * sin_1 ))
1014+ ITYPE_t size ) nogil except - 1 :
1015+ return 2 * asin (sqrt (self .rdist (x1 , x2 , size )))  # delegates to rdist instead of repeating its expression
10081016
10091017 cdef inline DTYPE_t _rdist_to_dist (self , DTYPE_t rdist ) nogil except - 1 :
10101018 return 2 * asin (sqrt (rdist ))
0 commit comments