Skip to content

Commit db30f6c

Browse files
authored
MAINT Add parameter validation for LabelPropagation and LabelSpreading (scikit-learn#24211)
1 parent dff6081 commit db30f6c

File tree

3 files changed

+21
-28
lines changed

3 files changed

+21
-28
lines changed

sklearn/semi_supervised/_label_propagation.py

Lines changed: 21 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@
5656
# Utkarsh Upadhyay <[email protected]>
5757
# License: BSD
5858
from abc import ABCMeta, abstractmethod
59+
from numbers import Integral, Real
5960

6061
import warnings
6162
import numpy as np
@@ -68,6 +69,7 @@
6869
from ..utils.extmath import safe_sparse_dot
6970
from ..utils.multiclass import check_classification_targets
7071
from ..utils.validation import check_is_fitted
72+
from ..utils._param_validation import Interval, StrOptions
7173
from ..exceptions import ConvergenceWarning
7274

7375

@@ -105,6 +107,16 @@ class BaseLabelPropagation(ClassifierMixin, BaseEstimator, metaclass=ABCMeta):
105107
for more details.
106108
"""
107109

110+
_parameter_constraints: dict = {
111+
"kernel": [StrOptions({"knn", "rbf"}), callable],
112+
"gamma": [Interval(Real, 0, None, closed="left")],
113+
"n_neighbors": [Interval(Integral, 0, None, closed="neither")],
114+
"alpha": [None, Interval(Real, 0, 1, closed="neither")],
115+
"max_iter": [Interval(Integral, 0, None, closed="neither")],
116+
"tol": [Interval(Real, 0, None, closed="left")],
117+
"n_jobs": [None, Integral],
118+
}
119+
108120
def __init__(
109121
self,
110122
kernel="rbf",
@@ -152,13 +164,6 @@ def _get_kernel(self, X, y=None):
152164
return self.kernel(X, X)
153165
else:
154166
return self.kernel(X, y)
155-
else:
156-
raise ValueError(
157-
"%s is not a valid kernel. Only rbf and knn"
158-
" or an explicit function "
159-
" are supported at this time."
160-
% self.kernel
161-
)
162167

163168
@abstractmethod
164169
def _build_graph(self):
@@ -246,6 +251,7 @@ def fit(self, X, y):
246251
self : object
247252
Returns the instance itself.
248253
"""
254+
self._validate_params()
249255
X, y = self._validate_data(X, y)
250256
self.X_ = X
251257
check_classification_targets(y)
@@ -261,14 +267,6 @@ def fit(self, X, y):
261267

262268
n_samples, n_classes = len(y), len(classes)
263269

264-
alpha = self.alpha
265-
if self._variant == "spreading" and (
266-
alpha is None or alpha <= 0.0 or alpha >= 1.0
267-
):
268-
raise ValueError(
269-
"alpha=%s is invalid: it must be inside the open interval (0, 1)"
270-
% alpha
271-
)
272270
y = np.asarray(y)
273271
unlabeled = y == -1
274272

@@ -283,7 +281,7 @@ def fit(self, X, y):
283281
y_static[unlabeled] = 0
284282
else:
285283
# LabelSpreading
286-
y_static *= 1 - alpha
284+
y_static *= 1 - self.alpha
287285

288286
l_previous = np.zeros((self.X_.shape[0], n_classes))
289287

@@ -310,7 +308,7 @@ def fit(self, X, y):
310308
else:
311309
# clamp
312310
self.label_distributions_ = (
313-
np.multiply(alpha, self.label_distributions_) + y_static
311+
np.multiply(self.alpha, self.label_distributions_) + y_static
314312
)
315313
else:
316314
warnings.warn(
@@ -417,6 +415,9 @@ class LabelPropagation(BaseLabelPropagation):
417415

418416
_variant = "propagation"
419417

418+
_parameter_constraints: dict = {**BaseLabelPropagation._parameter_constraints}
419+
_parameter_constraints.pop("alpha")
420+
420421
def __init__(
421422
self,
422423
kernel="rbf",
@@ -573,6 +574,9 @@ class LabelSpreading(BaseLabelPropagation):
573574

574575
_variant = "spreading"
575576

577+
_parameter_constraints: dict = {**BaseLabelPropagation._parameter_constraints}
578+
_parameter_constraints["alpha"] = [Interval(Real, 0, 1, closed="neither")]
579+
576580
def __init__(
577581
self,
578582
kernel="rbf",

sklearn/semi_supervised/tests/test_label_propagation.py

Lines changed: 0 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -122,15 +122,6 @@ def test_label_propagation_closed_form(global_dtype):
122122
assert_allclose(expected, clf.label_distributions_, atol=1e-4)
123123

124124

125-
@pytest.mark.parametrize("alpha", [-0.1, 0, 1, 1.1, None])
126-
def test_valid_alpha(global_dtype, alpha):
127-
n_classes = 2
128-
X, y = make_classification(n_classes=n_classes, n_samples=200, random_state=0)
129-
X = X.astype(global_dtype)
130-
with pytest.raises(ValueError):
131-
label_propagation.LabelSpreading(alpha=alpha).fit(X, y)
132-
133-
134125
def test_convergence_speed():
135126
# This is a non-regression test for #5774
136127
X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 2.5]])

sklearn/tests/test_common.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -469,8 +469,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
469469
"DictionaryLearning",
470470
"HashingVectorizer",
471471
"IterativeImputer",
472-
"LabelPropagation",
473-
"LabelSpreading",
474472
"LatentDirichletAllocation",
475473
"MiniBatchDictionaryLearning",
476474
"MultiTaskElasticNet",

0 commit comments

Comments
 (0)