Skip to content

Commit fb4eb1e

Browse files
MAINT Parameters validation for MultiTaskElasticNet and MultiTaskLasso (scikit-learn#24295)
Co-authored-by: jeremie du boisberranger <[email protected]>
1 parent 5e9fa42 commit fb4eb1e

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

sklearn/linear_model/_coordinate_descent.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2410,6 +2410,12 @@ class MultiTaskElasticNet(Lasso):
     [0.0872422 0.0872422]
     """

+    _parameter_constraints: dict = {
+        **ElasticNet._parameter_constraints,
+    }
+    for param in ("precompute", "positive"):
+        _parameter_constraints.pop(param)
+
     def __init__(
         self,
         alpha=1.0,
@@ -2459,6 +2465,8 @@ def fit(self, X, y):
         To avoid memory re-allocation it is advised to allocate the
         initial data in memory directly using that format.
         """
+        self._validate_params()
+
         _normalize = _deprecate_normalize(
             self.normalize, default=False, estimator_name=self.__class__.__name__
         )
@@ -2501,8 +2509,6 @@ def fit(self, X, y):

         self.coef_ = np.asfortranarray(self.coef_)  # coef contiguous in memory

-        if self.selection not in ["random", "cyclic"]:
-            raise ValueError("selection should be either random or cyclic.")
         random = self.selection == "random"

         (
@@ -2660,6 +2666,11 @@ class MultiTaskLasso(MultiTaskElasticNet):
     [-0.41888636 -0.87382323]
     """

+    _parameter_constraints: dict = {
+        **MultiTaskElasticNet._parameter_constraints,
+    }
+    _parameter_constraints.pop("l1_ratio")
+
     def __init__(
         self,
         alpha=1.0,

sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1385,10 +1385,6 @@ def test_convergence_warnings():
     X = random_state.standard_normal((1000, 500))
     y = random_state.standard_normal((1000, 3))

-    # check that the model fails to converge (a negative dual gap cannot occur)
-    with pytest.warns(ConvergenceWarning):
-        MultiTaskElasticNet(max_iter=1, tol=-1).fit(X, y)
-
     # check that the model converges w/o convergence warnings
     with warnings.catch_warnings():
         warnings.simplefilter("error", ConvergenceWarning)

sklearn/tests/test_common.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -466,8 +466,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
 PARAM_VALIDATION_ESTIMATORS_TO_IGNORE = [
     "DictionaryLearning",
     "MiniBatchDictionaryLearning",
-    "MultiTaskElasticNet",
-    "MultiTaskLasso",
     "Nystroem",
     "OAS",
     "OPTICS",

0 commit comments

Comments
 (0)