Skip to content

Commit 989f9c7

Browse files
jnothman authored and glemaitre committed
FIX check_methods_subset_invariance where estimator produces sparse output (scikit-learn#11173)
1 parent 86ab5a4 commit 989f9c7

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -676,7 +676,7 @@ def check_fit2d_predict1d(name, estimator_orig):
676676
getattr(estimator, method), X[0])
677677

678678

679-
def _apply_func(func, X):
679+
def _apply_on_subsets(func, X):
680680
# apply function on the whole set and on mini batches
681681
result_full = func(X)
682682
n_features = X.shape[1]
@@ -687,6 +687,9 @@ def _apply_func(func, X):
687687
result_full = result_full[0]
688688
result_by_batch = list(map(lambda x: x[0], result_by_batch))
689689

690+
if sparse.issparse(result_full):
691+
result_full = result_full.A
692+
result_by_batch = [x.A for x in result_by_batch]
690693
return np.ravel(result_full), np.ravel(result_by_batch)
691694

692695

@@ -722,7 +725,7 @@ def check_methods_subset_invariance(name, estimator_orig):
722725
raise SkipTest(msg)
723726

724727
if hasattr(estimator, method):
725-
result_full, result_by_batch = _apply_func(
728+
result_full, result_by_batch = _apply_on_subsets(
726729
getattr(estimator, method), X)
727730
assert_allclose(result_full, result_by_batch,
728731
atol=1e-7, err_msg=msg)

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -161,6 +161,21 @@ def predict(self, X):
161161
return np.zeros(X.shape[0])
162162

163163

164+
class SparseTransformer(BaseEstimator):
165+
def fit(self, X, y=None):
166+
self.X_shape_ = check_array(X).shape
167+
return self
168+
169+
def fit_transform(self, X, y=None):
170+
return self.fit(X, y).transform(X)
171+
172+
def transform(self, X):
173+
X = check_array(X)
174+
if X.shape[1] != self.X_shape_[1]:
175+
raise ValueError('Bad number of features')
176+
return sp.csr_matrix(X)
177+
178+
164179
def test_check_estimator():
165180
# tests that the estimator actually fails on "bad" estimators.
166181
# not a complete test of all checks, which are very extensive.
@@ -235,6 +250,9 @@ def test_check_estimator():
235250
sys.stdout = old_stdout
236251
assert_true(msg in string_buffer.getvalue())
237252

253+
# non-regression test for estimators transforming to sparse data
254+
check_estimator(SparseTransformer())
255+
238256
# doesn't error on actual estimator
239257
check_estimator(AdaBoostClassifier)
240258
check_estimator(AdaBoostClassifier())

0 commit comments

Comments
 (0)