Skip to content

Commit b86f3d5

Browse files
author
dengemann
committed
ENH: add new Warning class, improve tests, update docs
1 parent dcd6292 commit b86f3d5

File tree

3 files changed

+49
-14
lines changed

3 files changed

+49
-14
lines changed

sklearn/utils/extmath.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .fixes import qr_economic
1414
from ._logistic_sigmoid import _log_logistic_sigmoid
1515
from ..externals.six.moves import xrange
16-
from .validation import array2d, DataConversionWarning
16+
from .validation import array2d, NonBLASDotWarning
1717

1818

1919
def norm(v):
@@ -80,23 +80,35 @@ def _fast_dot(A, B):
8080
Parameters
8181
----------
8282
A, B: instance of np.ndarray
83-
input matrices.
83+
input matrices. Both must share the same dtype (float32 or float64)
84+
and be 2-dimensional with more than one row and column. If these
85+
requirements are not met, np.dot(A, B) is returned instead and a
86+
NonBLASDotWarning is issued. This warning is ignored by default; to
87+
enable it, execute the following lines of code:
88+
89+
>> import warnings
90+
>> from sklearn.utils.validation import NonBLASDotWarning
91+
>> warnings.simplefilter('always', NonBLASDotWarning)
8492
"""
8593

8694
if B.shape[0] != A.shape[A.ndim - 1]: # check adopted from '_dotblas.c'
87-
raise ValueError('matrices are not aligned')
95+
msg = ('Invalid array shapes: A.shape[%d] should be the same as '
96+
'B.shape[0]. Got A.shape=%r B.shape=%r' % (A.ndim - 1,
97+
A.shape, B.shape))
98+
raise ValueError(msg)
8899

89100
if A.dtype != B.dtype or any(x.dtype not in (np.float32, np.float64)
90-
for x in [A, B]):
101+
for x in [A, B]):
91102
warnings.warn('Data must be of same type. Supported types '
92103
'are 32 and 64 bit float. '
93-
'Falling back to np.dot.', DataConversionWarning)
104+
'Falling back to np.dot.', NonBLASDotWarning)
94105
return np.dot(A, B)
106+
95107
if ((A.ndim == 1 or B.ndim == 1) or
96108
(min(A.shape) == 1) or (min(B.shape) == 1) or
97109
(A.ndim != 2) or (B.ndim != 2)):
98110
warnings.warn('Data must be 2D with more than one colum / row.'
99-
'Falling back to np.dot', DataConversionWarning)
111+
'Falling back to np.dot', NonBLASDotWarning)
100112
return np.dot(A, B)
101113

102114
dot = linalg.get_blas_funcs('gemm', (A, B))
@@ -110,8 +122,7 @@ def _fast_dot(A, B):
110122
fast_dot = _fast_dot
111123
except (ImportError, AttributeError):
112124
fast_dot = np.dot
113-
warnings.warn('Could not import BLAS, falling back to np.dot',
114-
DataConversionWarning)
125+
warnings.warn('Could not import BLAS, falling back to np.dot')
115126

116127

117128
def density(w, **kwargs):

sklearn/utils/tests/test_extmath.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from sklearn.utils.extmath import cartesian
2525
from sklearn.utils.extmath import logistic_sigmoid
2626
from sklearn.utils.extmath import fast_dot
27-
from sklearn.utils.validation import DataConversionWarning
27+
from sklearn.utils.validation import NonBLASDotWarning
2828
from sklearn.datasets.samples_generator import make_low_rank_matrix
2929

3030

@@ -303,11 +303,29 @@ def test_fast_dot():
303303
has_blas = False
304304

305305
if has_blas:
306-
for dt1, dt2 in [['f8', 'f4'], ['i4', 'i4']]:
307-
with warnings.catch_warnings(record=True) as w:
308-
warnings.simplefilter("always", DataConversionWarning)
306+
# test dispatch to np.dot
307+
with warnings.catch_warnings(record=True) as w:
308+
warnings.simplefilter('always', NonBLASDotWarning)
309+
# mismatched or unsupported dtypes
310+
for dt1, dt2 in [['f8', 'f4'], ['i4', 'i4']]:
309311
fast_dot(A.astype(dt1), B.astype(dt2).T)
310-
assert_true(len(w) == 1)
312+
assert_true(type(w.pop(-1)) == NonBLASDotWarning)
313+
# malformed data
314+
# empty 1-D input (size == 0)
315+
E = np.empty(0)
316+
fast_dot(E, E)
317+
assert_true(type(w.pop(-1)) == NonBLASDotWarning)
318+
## ndim == 1
319+
fast_dot(A, A[0])
320+
assert_true(type(w.pop(-1)) == NonBLASDotWarning)
321+
## ndim > 2
322+
fast_dot(A.T, np.array([A, A]))
323+
assert_true(type(w.pop(-1)) == NonBLASDotWarning)
324+
## min(shape) == 1
325+
fast_dot(A, A[0, :][None, :])
326+
assert_true(type(w.pop(-1)) == NonBLASDotWarning)
327+
# test for matrix mismatch error
328+
assert_raises(ValueError, fast_dot, A, A)
311329

312330
# test cov-like use case + dtypes
313331
for dtype in ['f8', 'f4']:

sklearn/utils/validation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,16 @@ class DataConversionWarning(UserWarning):
1616
"A warning on implicit data conversions happening in the code"
1717
pass
1818

19-
2019
warnings.simplefilter("always", DataConversionWarning)
2120

2221

22+
class NonBLASDotWarning(UserWarning):
23+
"A warning on implicit dispatch to numpy.dot"
24+
pass
25+
26+
warnings.simplefilter('ignore', NonBLASDotWarning)
27+
28+
2329
def _assert_all_finite(X):
2430
"""Like assert_all_finite, but only for ndarray."""
2531
if (X.dtype.char in np.typecodes['AllFloat'] and not np.isfinite(X.sum())

0 commit comments

Comments
 (0)