Skip to content

Commit e2eba1f

Browse files
committed
Merge pull request scikit-learn#5688 from amueller/robust_scaler_1column_fix
[MRG+2] fix 1 sparse row scaling in robust scaler
2 parents 0188e68 + 61df16e commit e2eba1f

File tree

2 files changed, with 21 additions and 14 deletions.

sklearn/preprocessing/data.py

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
inplace_csr_row_normalize_l2)
2525
from ..utils.sparsefuncs import (inplace_column_scale,
2626
mean_variance_axis, incr_mean_variance_axis,
27-
min_max_axis, inplace_row_scale)
27+
min_max_axis)
2828
from ..utils.validation import check_is_fitted, FLOAT_DTYPES
2929

3030

@@ -984,10 +984,7 @@ def transform(self, X, y=None):
984984

985985
if sparse.issparse(X):
986986
if self.with_scaling:
987-
if X.shape[0] == 1:
988-
inplace_row_scale(X, 1.0 / self.scale_)
989-
elif self.axis == 0:
990-
inplace_column_scale(X, 1.0 / self.scale_)
987+
inplace_column_scale(X, 1.0 / self.scale_)
991988
else:
992989
if self.with_centering:
993990
X -= self.center_
@@ -1013,10 +1010,7 @@ def inverse_transform(self, X):
10131010

10141011
if sparse.issparse(X):
10151012
if self.with_scaling:
1016-
if X.shape[0] == 1:
1017-
inplace_row_scale(X, self.scale_)
1018-
else:
1019-
inplace_column_scale(X, self.scale_)
1013+
inplace_column_scale(X, self.scale_)
10201014
else:
10211015
if self.with_scaling:
10221016
X *= self.scale_

sklearn/preprocessing/tests/test_data.py

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -436,7 +436,7 @@ def test_standard_scaler_trasform_with_partial_fit():
436436
scaler_incr = StandardScaler()
437437
for i, batch in enumerate(gen_batches(X.shape[0], 1)):
438438

439-
X_sofar = X[:(i+1), :]
439+
X_sofar = X[:(i + 1), :]
440440
chunks_copy = X_sofar.copy()
441441
scaled_batch = StandardScaler().fit_transform(X_sofar)
442442

@@ -784,6 +784,20 @@ def test_robust_scaler_2d_arrays():
784784
assert_array_almost_equal(X_scaled.std(axis=0)[0], 0)
785785

786786

787+
def test_robust_scaler_transform_one_row_csr():
788+
# Check RobustScaler on transforming csr matrix with one row
789+
rng = np.random.RandomState(0)
790+
X = rng.randn(4, 5)
791+
single_row = np.array([[0.1, 1., 2., 0., -1.]])
792+
scaler = RobustScaler(with_centering=False)
793+
scaler = scaler.fit(X)
794+
row_trans = scaler.transform(sparse.csr_matrix(single_row))
795+
row_expected = single_row / scaler.scale_
796+
assert_array_almost_equal(row_trans.toarray(), row_expected)
797+
row_scaled_back = scaler.inverse_transform(row_trans)
798+
assert_array_almost_equal(single_row, row_scaled_back.toarray())
799+
800+
787801
def test_robust_scaler_iris():
788802
X = iris.data
789803
scaler = RobustScaler()
@@ -922,7 +936,7 @@ def test_maxabs_scaler_zero_variance_features():
922936

923937

924938
def test_maxabs_scaler_large_negative_value():
925-
"""Check MaxAbsScaler on toy data with a large negative value"""
939+
# Check MaxAbsScaler on toy data with a large negative value
926940
X = [[0., 1., +0.5, -1.0],
927941
[0., 1., -0.3, -0.5],
928942
[0., 1., -100.0, 0.0],
@@ -938,7 +952,7 @@ def test_maxabs_scaler_large_negative_value():
938952

939953

940954
def test_maxabs_scaler_transform_one_row_csr():
941-
"""Check MaxAbsScaler on transforming csr matrix with one row"""
955+
# Check MaxAbsScaler on transforming csr matrix with one row
942956
X = sparse.csr_matrix([[0.5, 1., 1.]])
943957
scaler = MaxAbsScaler()
944958
scaler = scaler.fit(X)
@@ -1505,8 +1519,7 @@ def test_one_hot_encoder_unknown_transform():
15051519
oh.fit(X)
15061520
assert_array_equal(
15071521
oh.transform(y).toarray(),
1508-
np.array([[0., 0., 0., 0., 1., 0., 0.]])
1509-
)
1522+
np.array([[0., 0., 0., 0., 1., 0., 0.]]))
15101523

15111524
# Raise error if handle_unknown is neither ignore or error.
15121525
oh = OneHotEncoder(handle_unknown='42')

0 commit comments

Comments
 (0)