44from copy import deepcopy
55
66import numpy as np
7- import numbers
7+ from numbers import Integral , Real
88
99from ._base import SelectorMixin
1010from ._base import _get_feature_importances
1111from ..base import BaseEstimator , clone , MetaEstimatorMixin
1212from ..utils ._tags import _safe_tags
1313from ..utils .validation import check_is_fitted , check_scalar , _num_features
14+ from ..utils ._param_validation import HasMethods , Interval , Options
1415
1516from ..exceptions import NotFittedError
1617from ..utils .metaestimators import available_if
@@ -229,6 +230,19 @@ class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
229230 2
230231 """
231232
233+ _parameter_constraints : dict = {
234+ "estimator" : [HasMethods ("fit" )],
235+ "threshold" : [Interval (Real , None , None , closed = "both" ), str , None ],
236+ "prefit" : ["boolean" ],
237+ "norm_order" : [
238+ Interval (Integral , None , - 1 , closed = "right" ),
239+ Interval (Integral , 1 , None , closed = "left" ),
240+ Options (Real , {np .inf , - np .inf }),
241+ ],
242+ "max_features" : [Interval (Integral , 0 , None , closed = "left" ), callable , None ],
243+ "importance_getter" : [str , callable ],
244+ }
245+
232246 def __init__ (
233247 self ,
234248 estimator ,
@@ -266,9 +280,7 @@ def _get_support_mask(self):
266280 "When `prefit=True` and `max_features` is a callable, call `fit` "
267281 "before calling `transform`."
268282 )
269- elif max_features is not None and not isinstance (
270- max_features , numbers .Integral
271- ):
283+ elif max_features is not None and not isinstance (max_features , Integral ):
272284 raise ValueError (
273285 f"`max_features` must be an integer. Got `max_features={ max_features } ` "
274286 "instead."
@@ -294,30 +306,19 @@ def _check_max_features(self, X):
294306 if self .max_features is not None :
295307 n_features = _num_features (X )
296308
297- if isinstance (self .max_features , numbers .Integral ):
298- check_scalar (
299- self .max_features ,
300- "max_features" ,
301- numbers .Integral ,
302- min_val = 0 ,
303- max_val = n_features ,
304- )
305- self .max_features_ = self .max_features
306- elif callable (self .max_features ):
309+ if callable (self .max_features ):
307310 max_features = self .max_features (X )
308- check_scalar (
309- max_features ,
310- "max_features(X)" ,
311- numbers .Integral ,
312- min_val = 0 ,
313- max_val = n_features ,
314- )
315- self .max_features_ = max_features
316- else :
317- raise TypeError (
318- "'max_features' must be either an int or a callable that takes"
319- f" 'X' as input. Got { self .max_features } instead."
320- )
311+ else : # int
312+ max_features = self .max_features
313+
314+ check_scalar (
315+ max_features ,
316+ "max_features" ,
317+ Integral ,
318+ min_val = 0 ,
319+ max_val = n_features ,
320+ )
321+ self .max_features_ = max_features
321322
322323 def fit (self , X , y = None , ** fit_params ):
323324 """Fit the SelectFromModel meta-transformer.
@@ -339,6 +340,7 @@ def fit(self, X, y=None, **fit_params):
339340 self : object
340341 Fitted estimator.
341342 """
343+ self ._validate_params ()
342344 self ._check_max_features (X )
343345
344346 if self .prefit :
@@ -393,10 +395,14 @@ def partial_fit(self, X, y=None, **fit_params):
393395 self : object
394396 Fitted estimator.
395397 """
396- self ._check_max_features (X )
398+ first_call = not hasattr (self , "estimator_" )
399+
400+ if first_call :
401+ self ._validate_params ()
402+ self ._check_max_features (X )
397403
398404 if self .prefit :
399- if not hasattr ( self , "estimator_" ) :
405+ if first_call :
400406 try :
401407 check_is_fitted (self .estimator )
402408 except NotFittedError as exc :
@@ -407,7 +413,6 @@ def partial_fit(self, X, y=None, **fit_params):
407413 self .estimator_ = deepcopy (self .estimator )
408414 return self
409415
410- first_call = not hasattr (self , "estimator_" )
411416 if first_call :
412417 self .estimator_ = clone (self .estimator )
413418 self .estimator_ .partial_fit (X , y , ** fit_params )
0 commit comments