11import pytest
22import numpy as np
33
4+ from scipy import sparse
5+
46from sklearn .datasets import load_iris
57from sklearn .model_selection import train_test_split
8+
9+ from sklearn .base import clone
10+
611from sklearn .preprocessing import QuantileTransformer
712from sklearn .preprocessing import MinMaxScaler
13+
814from sklearn .utils .testing import assert_array_equal
915from sklearn .utils .testing import assert_allclose
1016
1117iris = load_iris ()
1218
1319
def _get_valid_samples_by_column(X, col):
    """Select the entries of one column of X that are not NaN.

    Parameters
    ----------
    X : array
        Indexed as ``X[:, col]``, so it is expected to be a 2D numpy array.
    col : int
        Index of the column to extract.

    Returns
    -------
    array
        Column ``col`` of X kept as a single-column 2D slice, restricted to
        the rows where that column is not NaN.
    """
    column_values = X[:, col]
    # Boolean mask over rows: True where the value is an actual number.
    is_valid = np.logical_not(np.isnan(column_values))
    # Indexing with [col] (a list) keeps the column dimension.
    single_column = X[:, [col]]
    return single_column[is_valid]
23+
24+
1425@pytest .mark .parametrize (
15- "est" ,
16- [MinMaxScaler (),
17- QuantileTransformer (n_quantiles = 10 , random_state = 42 )]
26+ "est, support_sparse " ,
27+ [( MinMaxScaler (), False ),
28+ ( QuantileTransformer (n_quantiles = 10 , random_state = 42 ), True )]
1829)
19- def test_missing_value_handling (est ):
30+ def test_missing_value_handling (est , support_sparse ):
2031 # check that the preprocessing method lets NaN values pass through
2132 rng = np .random .RandomState (42 )
2233 X = iris .data .copy ()
@@ -43,13 +54,30 @@ def test_missing_value_handling(est):
4354
4455 for i in range (X .shape [1 ]):
4556 # train only on non-NaN
46- est .fit (X_train [:, [ i ]][ ~ np . isnan (X_train [: , i ])] )
57+ est .fit (_get_valid_samples_by_column (X_train , i ) )
4758 # check transforming with NaN works even when training without NaN
4859 Xt_col = est .transform (X_test [:, [i ]])
4960 assert_array_equal (Xt_col , Xt [:, [i ]])
5061 # check non-NaN is handled as before - the 1st column is all nan
5162 if not np .isnan (X_test [:, i ]).all ():
5263 Xt_col_nonan = est .transform (
53- X_test [:, [ i ]][ ~ np . isnan (X_test [: , i ])] )
64+ _get_valid_samples_by_column (X_test , i ) )
5465 assert_array_equal (Xt_col_nonan ,
5566 Xt_col [~ np .isnan (Xt_col .squeeze ())])
67+
68+ if support_sparse :
69+ est_dense = clone (est )
70+ est_sparse = clone (est )
71+
72+ Xt_dense = est_dense .fit (X_train ).transform (X_test )
73+ Xt_inv_dense = est_dense .inverse_transform (Xt_dense )
74+ for sparse_constructor in (sparse .csr_matrix , sparse .csc_matrix ,
75+ sparse .bsr_matrix , sparse .coo_matrix ,
76+ sparse .dia_matrix , sparse .dok_matrix ,
77+ sparse .lil_matrix ):
78+ # check that the dense and sparse inputs lead to the same results
79+ Xt_sparse = (est_sparse .fit (sparse_constructor (X_train ))
80+ .transform (sparse_constructor (X_test )))
81+ assert_allclose (Xt_sparse .A , Xt_dense )
82+ Xt_inv_sparse = est_sparse .inverse_transform (Xt_sparse )
83+ assert_allclose (Xt_inv_sparse .A , Xt_inv_dense )
0 commit comments