5656# Utkarsh Upadhyay <[email protected] > 5757# License: BSD
5858from abc import ABCMeta , abstractmethod
59+ from numbers import Integral , Real
5960
6061import warnings
6162import numpy as np
6869from ..utils .extmath import safe_sparse_dot
6970from ..utils .multiclass import check_classification_targets
7071from ..utils .validation import check_is_fitted
72+ from ..utils ._param_validation import Interval , StrOptions
7173from ..exceptions import ConvergenceWarning
7274
7375
@@ -105,6 +107,16 @@ class BaseLabelPropagation(ClassifierMixin, BaseEstimator, metaclass=ABCMeta):
105107 for more details.
106108 """
107109
110+ _parameter_constraints : dict = {
111+ "kernel" : [StrOptions ({"knn" , "rbf" }), callable ],
112+ "gamma" : [Interval (Real , 0 , None , closed = "left" )],
113+ "n_neighbors" : [Interval (Integral , 0 , None , closed = "neither" )],
114+ "alpha" : [None , Interval (Real , 0 , 1 , closed = "neither" )],
115+ "max_iter" : [Interval (Integral , 0 , None , closed = "neither" )],
116+ "tol" : [Interval (Real , 0 , None , closed = "left" )],
117+ "n_jobs" : [None , Integral ],
118+ }
119+
108120 def __init__ (
109121 self ,
110122 kernel = "rbf" ,
@@ -152,13 +164,6 @@ def _get_kernel(self, X, y=None):
152164 return self .kernel (X , X )
153165 else :
154166 return self .kernel (X , y )
155- else :
156- raise ValueError (
157- "%s is not a valid kernel. Only rbf and knn"
158- " or an explicit function "
159- " are supported at this time."
160- % self .kernel
161- )
162167
163168 @abstractmethod
164169 def _build_graph (self ):
@@ -246,6 +251,7 @@ def fit(self, X, y):
246251 self : object
247252 Returns the instance itself.
248253 """
254+ self ._validate_params ()
249255 X , y = self ._validate_data (X , y )
250256 self .X_ = X
251257 check_classification_targets (y )
@@ -261,14 +267,6 @@ def fit(self, X, y):
261267
262268 n_samples , n_classes = len (y ), len (classes )
263269
264- alpha = self .alpha
265- if self ._variant == "spreading" and (
266- alpha is None or alpha <= 0.0 or alpha >= 1.0
267- ):
268- raise ValueError (
269- "alpha=%s is invalid: it must be inside the open interval (0, 1)"
270- % alpha
271- )
272270 y = np .asarray (y )
273271 unlabeled = y == - 1
274272
@@ -283,7 +281,7 @@ def fit(self, X, y):
283281 y_static [unlabeled ] = 0
284282 else :
285283 # LabelSpreading
286- y_static *= 1 - alpha
284+ y_static *= 1 - self . alpha
287285
288286 l_previous = np .zeros ((self .X_ .shape [0 ], n_classes ))
289287
@@ -310,7 +308,7 @@ def fit(self, X, y):
310308 else :
311309 # clamp
312310 self .label_distributions_ = (
313- np .multiply (alpha , self .label_distributions_ ) + y_static
311+ np .multiply (self . alpha , self .label_distributions_ ) + y_static
314312 )
315313 else :
316314 warnings .warn (
@@ -417,6 +415,9 @@ class LabelPropagation(BaseLabelPropagation):
417415
418416 _variant = "propagation"
419417
418+ _parameter_constraints : dict = {** BaseLabelPropagation ._parameter_constraints }
419+ _parameter_constraints .pop ("alpha" )
420+
420421 def __init__ (
421422 self ,
422423 kernel = "rbf" ,
@@ -573,6 +574,9 @@ class LabelSpreading(BaseLabelPropagation):
573574
574575 _variant = "spreading"
575576
577+ _parameter_constraints : dict = {** BaseLabelPropagation ._parameter_constraints }
578+ _parameter_constraints ["alpha" ] = [Interval (Real , 0 , 1 , closed = "neither" )]
579+
576580 def __init__ (
577581 self ,
578582 kernel = "rbf" ,
0 commit comments