Skip to content

Commit cf6c122

Browse files
MAINT Add parameter validation to SelectFromModel (scikit-learn#24213)
Co-authored-by: jeremie du boisberranger <[email protected]>
1 parent 7c835d5 commit cf6c122

File tree

3 files changed

+43
-41
lines changed

3 files changed

+43
-41
lines changed

sklearn/feature_selection/_from_model.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
from copy import deepcopy
55

66
import numpy as np
7-
import numbers
7+
from numbers import Integral, Real
88

99
from ._base import SelectorMixin
1010
from ._base import _get_feature_importances
1111
from ..base import BaseEstimator, clone, MetaEstimatorMixin
1212
from ..utils._tags import _safe_tags
1313
from ..utils.validation import check_is_fitted, check_scalar, _num_features
14+
from ..utils._param_validation import HasMethods, Interval, Options
1415

1516
from ..exceptions import NotFittedError
1617
from ..utils.metaestimators import available_if
@@ -229,6 +230,19 @@ class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
229230
2
230231
"""
231232

233+
_parameter_constraints: dict = {
234+
"estimator": [HasMethods("fit")],
235+
"threshold": [Interval(Real, None, None, closed="both"), str, None],
236+
"prefit": ["boolean"],
237+
"norm_order": [
238+
Interval(Integral, None, -1, closed="right"),
239+
Interval(Integral, 1, None, closed="left"),
240+
Options(Real, {np.inf, -np.inf}),
241+
],
242+
"max_features": [Interval(Integral, 0, None, closed="left"), callable, None],
243+
"importance_getter": [str, callable],
244+
}
245+
232246
def __init__(
233247
self,
234248
estimator,
@@ -266,9 +280,7 @@ def _get_support_mask(self):
266280
"When `prefit=True` and `max_features` is a callable, call `fit` "
267281
"before calling `transform`."
268282
)
269-
elif max_features is not None and not isinstance(
270-
max_features, numbers.Integral
271-
):
283+
elif max_features is not None and not isinstance(max_features, Integral):
272284
raise ValueError(
273285
f"`max_features` must be an integer. Got `max_features={max_features}` "
274286
"instead."
@@ -294,30 +306,19 @@ def _check_max_features(self, X):
294306
if self.max_features is not None:
295307
n_features = _num_features(X)
296308

297-
if isinstance(self.max_features, numbers.Integral):
298-
check_scalar(
299-
self.max_features,
300-
"max_features",
301-
numbers.Integral,
302-
min_val=0,
303-
max_val=n_features,
304-
)
305-
self.max_features_ = self.max_features
306-
elif callable(self.max_features):
309+
if callable(self.max_features):
307310
max_features = self.max_features(X)
308-
check_scalar(
309-
max_features,
310-
"max_features(X)",
311-
numbers.Integral,
312-
min_val=0,
313-
max_val=n_features,
314-
)
315-
self.max_features_ = max_features
316-
else:
317-
raise TypeError(
318-
"'max_features' must be either an int or a callable that takes"
319-
f" 'X' as input. Got {self.max_features} instead."
320-
)
311+
else: # int
312+
max_features = self.max_features
313+
314+
check_scalar(
315+
max_features,
316+
"max_features",
317+
Integral,
318+
min_val=0,
319+
max_val=n_features,
320+
)
321+
self.max_features_ = max_features
321322

322323
def fit(self, X, y=None, **fit_params):
323324
"""Fit the SelectFromModel meta-transformer.
@@ -339,6 +340,7 @@ def fit(self, X, y=None, **fit_params):
339340
self : object
340341
Fitted estimator.
341342
"""
343+
self._validate_params()
342344
self._check_max_features(X)
343345

344346
if self.prefit:
@@ -393,10 +395,14 @@ def partial_fit(self, X, y=None, **fit_params):
393395
self : object
394396
Fitted estimator.
395397
"""
396-
self._check_max_features(X)
398+
first_call = not hasattr(self, "estimator_")
399+
400+
if first_call:
401+
self._validate_params()
402+
self._check_max_features(X)
397403

398404
if self.prefit:
399-
if not hasattr(self, "estimator_"):
405+
if first_call:
400406
try:
401407
check_is_fitted(self.estimator)
402408
except NotFittedError as exc:
@@ -407,7 +413,6 @@ def partial_fit(self, X, y=None, **fit_params):
407413
self.estimator_ = deepcopy(self.estimator)
408414
return self
409415

410-
first_call = not hasattr(self, "estimator_")
411416
if first_call:
412417
self.estimator_ = clone(self.estimator)
413418
self.estimator_.partial_fit(X, y, **fit_params)

sklearn/feature_selection/tests/test_from_model.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ def test_input_estimator_unchanged():
7373
@pytest.mark.parametrize(
7474
"max_features, err_type, err_msg",
7575
[
76-
(-1, ValueError, "max_features =="),
7776
(
7877
data.shape[1] + 1,
7978
ValueError,
@@ -82,17 +81,17 @@ def test_input_estimator_unchanged():
8281
(
8382
lambda X: 1.5,
8483
TypeError,
85-
"max_features(X) must be an instance of int, not float.",
84+
"max_features must be an instance of int, not float.",
8685
),
8786
(
88-
"gobbledigook",
89-
TypeError,
90-
"'max_features' must be either an int or a callable",
87+
lambda X: data.shape[1] + 1,
88+
ValueError,
89+
"max_features ==",
9190
),
9291
(
93-
"all",
94-
TypeError,
95-
"'max_features' must be either an int or a callable",
92+
lambda X: -1,
93+
ValueError,
94+
"max_features ==",
9695
),
9796
],
9897
)
@@ -629,8 +628,7 @@ def importance_getter(estimator):
629628
"error, err_msg, max_features",
630629
(
631630
[ValueError, "max_features == 10, must be <= 4", 10],
632-
[TypeError, "'max_features' must be either an int or a callable", "a"],
633-
[ValueError, r"max_features\(X\) == 5, must be <= 4", lambda x: x.shape[1] + 1],
631+
[ValueError, "max_features == 5, must be <= 4", lambda x: x.shape[1] + 1],
634632
),
635633
)
636634
def test_partial_fit_validate_max_features(error, err_msg, max_features):

sklearn/tests/test_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
477477
"RANSACRegressor",
478478
"RidgeCV",
479479
"RidgeClassifierCV",
480-
"SelectFromModel",
481480
]
482481

483482

0 commit comments

Comments (0)