@@ -54,7 +54,7 @@ def _return_float_dtype(X, Y):
5454 return X , Y , dtype
5555
5656
57- def check_pairwise_arrays (X , Y , precomputed = False ):
57+ def check_pairwise_arrays (X , Y , precomputed = False , dtype = None ):
5858 """ Set X and Y appropriately and checks inputs
5959
6060 If Y is None, it is set as a pointer to X (i.e. not a copy).
@@ -64,9 +64,9 @@ def check_pairwise_arrays(X, Y, precomputed=False):
6464
6565 Specifically, this function first ensures that both X and Y are arrays,
6666 then checks that they are at least two dimensional while ensuring that
67- their elements are floats. Finally, the function checks that the size
68- of the second dimension of the two arrays is equal, or the equivalent
69- check for a precomputed distance matrix.
67+ their elements are floats (or dtype if provided). Finally, the function
68+ checks that the size of the second dimension of the two arrays is equal, or
69+ the equivalent check for a precomputed distance matrix.
7070
7171 Parameters
7272 ----------
@@ -78,6 +78,12 @@ def check_pairwise_arrays(X, Y, precomputed=False):
7878 True if X is to be treated as precomputed distances to the samples in
7979 Y.
8080
81+ dtype : string, type, list of types or None (default=None)
82+ Data type required for X and Y. If None, the dtype will be an
83+ appropriate float type selected by _return_float_dtype.
84+
85+ .. versionadded:: 0.18
86+
8187 Returns
8288 -------
8389 safe_X : {array-like, sparse matrix}, shape (n_samples_a, n_features)
@@ -88,13 +94,21 @@ def check_pairwise_arrays(X, Y, precomputed=False):
8894 If Y was None, safe_Y will be a pointer to X.
8995
9096 """
91- X , Y , dtype = _return_float_dtype (X , Y )
97+ X , Y , dtype_float = _return_float_dtype (X , Y )
98+
99+ warn_on_dtype = dtype is not None
100+ estimator = 'check_pairwise_arrays'
101+ if dtype is None :
102+ dtype = dtype_float
92103
93104 if Y is X or Y is None :
94- X = Y = check_array (X , accept_sparse = 'csr' , dtype = dtype )
105+ X = Y = check_array (X , accept_sparse = 'csr' , dtype = dtype ,
106+ warn_on_dtype = warn_on_dtype , estimator = estimator )
95107 else :
96- X = check_array (X , accept_sparse = 'csr' , dtype = dtype )
97- Y = check_array (Y , accept_sparse = 'csr' , dtype = dtype )
108+ X = check_array (X , accept_sparse = 'csr' , dtype = dtype ,
109+ warn_on_dtype = warn_on_dtype , estimator = estimator )
110+ Y = check_array (Y , accept_sparse = 'csr' , dtype = dtype ,
111+ warn_on_dtype = warn_on_dtype , estimator = estimator )
98112
99113 if precomputed :
100114 if X .shape [1 ] != Y .shape [0 ]:
@@ -1208,7 +1222,11 @@ def pairwise_distances(X, Y=None, metric="euclidean", n_jobs=1, **kwds):
12081222 if issparse (X ) or issparse (Y ):
12091223 raise TypeError ("scipy distance metrics do not"
12101224 " support sparse matrices." )
1211- X , Y = check_pairwise_arrays (X , Y )
1225+
1226+ dtype = bool if metric in PAIRWISE_BOOLEAN_FUNCTIONS else None
1227+
1228+ X , Y = check_pairwise_arrays (X , Y , dtype = dtype )
1229+
12121230 if n_jobs == 1 and X is Y :
12131231 return distance .squareform (distance .pdist (X , metric = metric ,
12141232 ** kwds ))
@@ -1217,6 +1235,20 @@ def pairwise_distances(X, Y=None, metric="euclidean", n_jobs=1, **kwds):
12171235 return _parallel_pairwise (X , Y , func , n_jobs , ** kwds )
12181236
12191237
1238+ # These distances require boolean arrays when using scipy.spatial.distance
1239+ PAIRWISE_BOOLEAN_FUNCTIONS = [
1240+ 'dice' ,
1241+ 'jaccard' ,
1242+ 'kulsinski' ,
1243+ 'matching' ,
1244+ 'rogerstanimoto' ,
1245+ 'russellrao' ,
1246+ 'sokalmichener' ,
1247+ 'sokalsneath' ,
1248+ 'yule' ,
1249+ ]
1250+
1251+
12201252# Helper functions - distance
12211253PAIRWISE_KERNEL_FUNCTIONS = {
12221254 # If updating this dictionary, update the doc in both distance_metrics()
0 commit comments