|
5 | 5 | from sys import version_info |
6 | 6 |
|
7 | 7 | import numpy as np |
8 | | -from scipy import interpolate |
| 8 | +from scipy import interpolate, sparse |
9 | 9 |
|
10 | 10 | from sklearn.utils.testing import assert_array_almost_equal |
11 | 11 | from sklearn.utils.testing import assert_almost_equal |
@@ -443,6 +443,23 @@ def test_1d_multioutput_lasso_and_multitask_lasso_cv(): |
443 | 443 | assert_almost_equal(clf.intercept_, clf1.intercept_[0]) |
444 | 444 |
|
445 | 445 |
|
| 446 | +def test_sparse_input_dtype_enet_and_lassocv(): |
| 447 | + X, y, _, _ = build_dataset(n_features=10) |
| 448 | + clf = ElasticNetCV(n_alphas=5) |
| 449 | + clf.fit(sparse.csr_matrix(X), y) |
| 450 | + clf1 = ElasticNetCV(n_alphas=5) |
| 451 | + clf1.fit(sparse.csr_matrix(X, dtype=np.float32), y) |
| 452 | + assert_almost_equal(clf.alpha_, clf1.alpha_, decimal=6) |
| 453 | + assert_almost_equal(clf.coef_, clf1.coef_, decimal=6) |
| 454 | + |
| 455 | + clf = LassoCV(n_alphas=5) |
| 456 | + clf.fit(sparse.csr_matrix(X), y) |
| 457 | + clf1 = LassoCV(n_alphas=5) |
| 458 | + clf1.fit(sparse.csr_matrix(X, dtype=np.float32), y) |
| 459 | + assert_almost_equal(clf.alpha_, clf1.alpha_, decimal=6) |
| 460 | + assert_almost_equal(clf.coef_, clf1.coef_, decimal=6) |
| 461 | + |
| 462 | + |
446 | 463 | if __name__ == '__main__': |
447 | 464 | import nose |
448 | 465 | nose.runmodule() |
0 commit comments