2727from ..utils import as_float_array , atleast2d_or_csr , safe_asarray
2828from ..utils .extmath import safe_sparse_dot
2929from ..utils .sparsefuncs import (csc_mean_variance_axis0 ,
30+ csr_mean_variance_axis0 ,
31+ inplace_csr_column_scale ,
3032 inplace_csc_column_scale )
3133from .cd_fast import sparse_std
3234
@@ -50,15 +52,25 @@ def sparse_center_data(X, y, fit_intercept, normalize=False):
5052
5153 if fit_intercept :
5254 X_data = X .data
53- # copy if 'normalize' is True or X is not a csc matrix
54- X = sp .csc_matrix (X , copy = normalize )
55- X_mean , X_std = csc_mean_variance_axis0 (X )
55+
56+ # we might require not to change the csr matrix sometimes
57+ # store a copy if normalize is True.
58+ if sp .isspmatrix (X ) and X .getformat () == 'csr' :
59+ X = sp .csr_matrix (X , copy = normalize )
60+ sparse_mean_var_axis0 = csr_mean_variance_axis0
61+ sparse_column_scale = inplace_csr_column_scale
62+ else :
63+ X = sp .csc_matrix (X , copy = normalize )
64+ sparse_mean_var_axis0 = csc_mean_variance_axis0
65+ sparse_column_scale = inplace_csc_column_scale
66+
67+ X_mean , X_std = sparse_mean_var_axis0 (X )
5668 if normalize :
5769 X_std = sparse_std (
5870 X .shape [0 ], X .shape [1 ],
5971 X_data , X .indices , X .indptr , X_mean )
6072 X_std [X_std == 0 ] = 1
61- inplace_csc_column_scale (X , 1. / X_std )
73+ sparse_column_scale (X , 1. / X_std )
6274 else :
6375 X_std = np .ones (X .shape [1 ])
6476 y_mean = y .mean (axis = 0 )
0 commit comments