Skip to content

Commit 45d9182

Browse files
guiniol authored and TomDLT committed
[MRG+1] Fix float size in as_float_array (scikit-learn#8598)
* Fix float size in as_float_array
* Add tests for small ints in as_float_array
* Add test for object dtype with minor tweaks
1 parent 05feb7e commit 45d9182

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

sklearn/utils/tests/test_validation.py

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -42,16 +42,28 @@ def test_as_float_array():
4242
# Test function for as_float_array
4343
X = np.ones((3, 10), dtype=np.int32)
4444
X = X + np.arange(10, dtype=np.int32)
45-
# Checks that the return type is ok
4645
X2 = as_float_array(X, copy=False)
47-
np.testing.assert_equal(X2.dtype, np.float32)
46+
assert_equal(X2.dtype, np.float32)
4847
# Another test
4948
X = X.astype(np.int64)
5049
X2 = as_float_array(X, copy=True)
5150
# Checking that the array wasn't overwritten
5251
assert_true(as_float_array(X, False) is not X)
53-
# Checking that the new type is ok
54-
np.testing.assert_equal(X2.dtype, np.float64)
52+
assert_equal(X2.dtype, np.float64)
53+
# Test int dtypes <= 32bit
54+
tested_dtypes = [np.bool,
55+
np.int8, np.int16, np.int32,
56+
np.uint8, np.uint16, np.uint32]
57+
for dtype in tested_dtypes:
58+
X = X.astype(dtype)
59+
X2 = as_float_array(X)
60+
assert_equal(X2.dtype, np.float32)
61+
62+
# Test object dtype
63+
X = X.astype(object)
64+
X2 = as_float_array(X, copy=True)
65+
assert_equal(X2.dtype, np.float64)
66+
5567
# Here, X is of the right type, it shouldn't be modified
5668
X = np.ones((3, 2), dtype=np.float32)
5769
assert_true(as_float_array(X, copy=False) is X)

sklearn/utils/validation.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,11 @@ def as_float_array(X, copy=True, force_all_finite=True):
8080
elif X.dtype in [np.float32, np.float64]: # is numpy array
8181
return X.copy('F' if X.flags['F_CONTIGUOUS'] else 'C') if copy else X
8282
else:
83-
return X.astype(np.float32 if X.dtype == np.int32 else np.float64)
83+
if X.dtype.kind in 'uib' and X.dtype.itemsize <= 4:
84+
return_dtype = np.float32
85+
else:
86+
return_dtype = np.float64
87+
return X.astype(return_dtype)
8488

8589

8690
def _is_arraylike(x):

0 commit comments

Comments
 (0)