Skip to content

Commit 8ff19bd

Browse files
committed
Preserve CSR storage format when input is CSR in sparse_center_data
1 parent d9bd47d commit 8ff19bd

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

sklearn/linear_model/base.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from ..utils import as_float_array, atleast2d_or_csr, safe_asarray
2828
from ..utils.extmath import safe_sparse_dot
2929
from ..utils.sparsefuncs import (csc_mean_variance_axis0,
30+
csr_mean_variance_axis0,
31+
inplace_csr_column_scale,
3032
inplace_csc_column_scale)
3133
from .cd_fast import sparse_std
3234

@@ -50,15 +52,25 @@ def sparse_center_data(X, y, fit_intercept, normalize=False):
5052

5153
if fit_intercept:
5254
X_data = X.data
53-
# copy if 'normalize' is True or X is not a csc matrix
54-
X = sp.csc_matrix(X, copy=normalize)
55-
X_mean, X_std = csc_mean_variance_axis0(X)
55+
56+
# keep CSR input in CSR format instead of converting it to CSC;
57+
# work on a copy when normalize is True, since the column scaling is in place.
58+
if sp.isspmatrix(X) and X.getformat() == 'csr':
59+
X = sp.csr_matrix(X, copy=normalize)
60+
sparse_mean_var_axis0 = csr_mean_variance_axis0
61+
sparse_column_scale = inplace_csr_column_scale
62+
else:
63+
X = sp.csc_matrix(X, copy=normalize)
64+
sparse_mean_var_axis0 = csc_mean_variance_axis0
65+
sparse_column_scale = inplace_csc_column_scale
66+
67+
X_mean, X_std = sparse_mean_var_axis0(X)
5668
if normalize:
5769
X_std = sparse_std(
5870
X.shape[0], X.shape[1],
5971
X_data, X.indices, X.indptr, X_mean)
6072
X_std[X_std == 0] = 1
61-
inplace_csc_column_scale(X, 1. / X_std)
73+
sparse_column_scale(X, 1. / X_std)
6274
else:
6375
X_std = np.ones(X.shape[1])
6476
y_mean = y.mean(axis=0)

sklearn/linear_model/tests/test_base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sklearn.utils.testing import assert_array_almost_equal
1010
from sklearn.utils.testing import assert_equal
1111

12-
from sklearn.linear_model.base import LinearRegression
12+
from sklearn.linear_model.base import LinearRegression, sparse_center_data
1313
from sklearn.utils import check_random_state
1414
from sklearn.datasets.samples_generator import make_sparse_uncorrelated
1515
from sklearn.datasets.samples_generator import make_regression
@@ -110,3 +110,12 @@ def test_linear_regression_sparse_multiple_outcome(random_state=0):
110110
ols.fit(X, y.ravel())
111111
y_pred = ols.predict(X)
112112
assert_array_almost_equal(np.vstack((y_pred, y_pred)).T, Y_pred, decimal=3)
113+
114+
115+
def test_csr_sparse_center_data():
116+
"""Test output format of sparse_center_data, when input is csr"""
117+
X, y = make_regression()
118+
X[X < 2.5] = 0.0
119+
csr = sparse.csr_matrix(X)
120+
csr_, y, _, _, _ = sparse_center_data(csr, y, True)
121+
assert_equal(csr_.getformat(), 'csr')

0 commit comments

Comments
 (0)