Skip to content

Commit a015c4a

Browse files
committed
FIX: ENetCV and LassoCV now accept np.float32 input
1 parent 7c9a353 commit a015c4a

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

sklearn/linear_model/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@ def sparse_center_data(X, y, fit_intercept, normalize=False):
4848
if fit_intercept:
4949
# we might require not to change the csr matrix sometimes
5050
# store a copy if normalize is True.
51+
# Change dtype to float64 since mean_variance_axis0 accepts
52+
# it that way.
5153
if sp.isspmatrix(X) and X.getformat() == 'csr':
52-
X = sp.csr_matrix(X, copy=normalize)
54+
X = sp.csr_matrix(X, copy=normalize, dtype=np.float64)
5355
else:
54-
X = sp.csc_matrix(X, copy=normalize)
56+
X = sp.csc_matrix(X, copy=normalize, dtype=np.float64)
5557

5658
X_mean, X_var = mean_variance_axis0(X)
5759
if normalize:

sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sys import version_info
66

77
import numpy as np
8-
from scipy import interpolate
8+
from scipy import interpolate, sparse
99

1010
from sklearn.utils.testing import assert_array_almost_equal
1111
from sklearn.utils.testing import assert_almost_equal
@@ -443,6 +443,23 @@ def test_1d_multioutput_lasso_and_multitask_lasso_cv():
443443
assert_almost_equal(clf.intercept_, clf1.intercept_[0])
444444

445445

446+
def test_sparse_input_dtype_enet_and_lassocv():
447+
X, y, _, _ = build_dataset(n_features=10)
448+
clf = ElasticNetCV(n_alphas=5)
449+
clf.fit(sparse.csr_matrix(X), y)
450+
clf1 = ElasticNetCV(n_alphas=5)
451+
clf1.fit(sparse.csr_matrix(X, dtype=np.float32), y)
452+
assert_almost_equal(clf.alpha_, clf1.alpha_, decimal=6)
453+
assert_almost_equal(clf.coef_, clf1.coef_, decimal=6)
454+
455+
clf = LassoCV(n_alphas=5)
456+
clf.fit(sparse.csr_matrix(X), y)
457+
clf1 = LassoCV(n_alphas=5)
458+
clf1.fit(sparse.csr_matrix(X, dtype=np.float32), y)
459+
assert_almost_equal(clf.alpha_, clf1.alpha_, decimal=6)
460+
assert_almost_equal(clf.coef_, clf1.coef_, decimal=6)
461+
462+
446463
if __name__ == '__main__':
447464
import nose
448465
nose.runmodule()

0 commit comments

Comments
 (0)