Skip to content

Commit 96a02f3

Browse files
glemaitre authored and jnothman committed
[MRG+1] TST cover sparse matrix case for passing through NaN in transformer (scikit-learn#11012)
1 parent 002f95c commit 96a02f3

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

sklearn/preprocessing/tests/test_common.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,33 @@
11
import pytest
22
import numpy as np
33

4+
from scipy import sparse
5+
46
from sklearn.datasets import load_iris
57
from sklearn.model_selection import train_test_split
8+
9+
from sklearn.base import clone
10+
611
from sklearn.preprocessing import QuantileTransformer
712
from sklearn.preprocessing import MinMaxScaler
13+
814
from sklearn.utils.testing import assert_array_equal
915
from sklearn.utils.testing import assert_allclose
1016

1117
iris = load_iris()
1218

1319

20+
def _get_valid_samples_by_column(X, col):
    """Return the non-NaN entries of column ``col`` of X as a 2D column.

    Rows whose value in column ``col`` is NaN are dropped; the selected
    column keeps its trailing axis of length one.
    """
    nan_rows = np.isnan(X[:, col])
    # Index with a list so the column stays two-dimensional.
    column_2d = X[:, [col]]
    return column_2d[~nan_rows]
23+
24+
1425
@pytest.mark.parametrize(
15-
"est",
16-
[MinMaxScaler(),
17-
QuantileTransformer(n_quantiles=10, random_state=42)]
26+
"est, support_sparse",
27+
[(MinMaxScaler(), False),
28+
(QuantileTransformer(n_quantiles=10, random_state=42), True)]
1829
)
19-
def test_missing_value_handling(est):
30+
def test_missing_value_handling(est, support_sparse):
2031
# check that the preprocessing method let pass nan
2132
rng = np.random.RandomState(42)
2233
X = iris.data.copy()
@@ -43,13 +54,30 @@ def test_missing_value_handling(est):
4354

4455
for i in range(X.shape[1]):
4556
# train only on non-NaN
46-
est.fit(X_train[:, [i]][~np.isnan(X_train[:, i])])
57+
est.fit(_get_valid_samples_by_column(X_train, i))
4758
# check transforming with NaN works even when training without NaN
4859
Xt_col = est.transform(X_test[:, [i]])
4960
assert_array_equal(Xt_col, Xt[:, [i]])
5061
# check non-NaN is handled as before - the 1st column is all nan
5162
if not np.isnan(X_test[:, i]).all():
5263
Xt_col_nonan = est.transform(
53-
X_test[:, [i]][~np.isnan(X_test[:, i])])
64+
_get_valid_samples_by_column(X_test, i))
5465
assert_array_equal(Xt_col_nonan,
5566
Xt_col[~np.isnan(Xt_col.squeeze())])
67+
68+
if support_sparse:
69+
est_dense = clone(est)
70+
est_sparse = clone(est)
71+
72+
Xt_dense = est_dense.fit(X_train).transform(X_test)
73+
Xt_inv_dense = est_dense.inverse_transform(Xt_dense)
74+
for sparse_constructor in (sparse.csr_matrix, sparse.csc_matrix,
75+
sparse.bsr_matrix, sparse.coo_matrix,
76+
sparse.dia_matrix, sparse.dok_matrix,
77+
sparse.lil_matrix):
78+
# check that the dense and sparse inputs lead to the same results
79+
Xt_sparse = (est_sparse.fit(sparse_constructor(X_train))
80+
.transform(sparse_constructor(X_test)))
81+
assert_allclose(Xt_sparse.A, Xt_dense)
82+
Xt_inv_sparse = est_sparse.inverse_transform(Xt_sparse)
83+
assert_allclose(Xt_inv_sparse.A, Xt_inv_dense)

0 commit comments

Comments
 (0)