Skip to content

Commit 7211ee4

Browse files
authored
MNT Clean up deprecations for 1.8: fit_params (scikit-learn#32521)
1 parent 597646a commit 7211ee4

File tree

3 files changed

+9
-91
lines changed

3 files changed

+9
-91
lines changed

sklearn/model_selection/_plot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def from_estimator(
488488
random_state=random_state,
489489
error_score=error_score,
490490
return_times=False,
491-
fit_params=fit_params,
491+
params=fit_params,
492492
)
493493

494494
viz = cls(
@@ -864,7 +864,7 @@ def from_estimator(
864864
pre_dispatch=pre_dispatch,
865865
verbose=verbose,
866866
error_score=error_score,
867-
fit_params=fit_params,
867+
params=fit_params,
868868
)
869869

870870
viz = cls(

sklearn/model_selection/_validation.py

Lines changed: 7 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -54,35 +54,6 @@
5454
]
5555

5656

57-
def _check_params_groups_deprecation(fit_params, params, groups, version):
58-
"""A helper function to check deprecations on `groups` and `fit_params`.
59-
60-
# TODO(SLEP6): To be removed when set_config(enable_metadata_routing=False) is not
61-
# possible.
62-
"""
63-
if params is not None and fit_params is not None:
64-
raise ValueError(
65-
"`params` and `fit_params` cannot both be provided. Pass parameters "
66-
"via `params`. `fit_params` is deprecated and will be removed in "
67-
f"version {version}."
68-
)
69-
elif fit_params is not None:
70-
warnings.warn(
71-
(
72-
"`fit_params` is deprecated and will be removed in version {version}. "
73-
"Pass parameters via `params` instead."
74-
),
75-
FutureWarning,
76-
)
77-
params = fit_params
78-
79-
params = {} if params is None else params
80-
81-
_check_groups_routing_disabled(groups)
82-
83-
return params
84-
85-
8657
# TODO(SLEP6): To be removed when set_config(enable_metadata_routing=False) is not
8758
# possible.
8859
def _check_groups_routing_disabled(groups):
@@ -1446,7 +1417,6 @@ def _check_is_permutation(indices, n_samples):
14461417
"random_state": ["random_state"],
14471418
"verbose": ["verbose"],
14481419
"scoring": [StrOptions(set(get_scorer_names())), callable, None],
1449-
"fit_params": [dict, None],
14501420
"params": [dict, None],
14511421
},
14521422
prefer_skip_nested_validation=False, # estimator is not validated yet
@@ -1463,7 +1433,6 @@ def permutation_test_score(
14631433
random_state=0,
14641434
verbose=0,
14651435
scoring=None,
1466-
fit_params=None,
14671436
params=None,
14681437
):
14691438
"""Evaluate the significance of a cross-validated score with permutations.
@@ -1558,13 +1527,6 @@ def permutation_test_score(
15581527
- `None`: the `estimator`'s
15591528
:ref:`default evaluation criterion <scoring_api_overview>` is used.
15601529
1561-
fit_params : dict, default=None
1562-
Parameters to pass to the fit method of the estimator.
1563-
1564-
.. deprecated:: 1.6
1565-
This parameter is deprecated and will be removed in version 1.6. Use
1566-
``params`` instead.
1567-
15681530
params : dict, default=None
15691531
Parameters to pass to the `fit` method of the estimator, the scorer
15701532
and the cv splitter.
@@ -1624,7 +1586,8 @@ def permutation_test_score(
16241586
>>> print(f"P-value: {pvalue:.3f}")
16251587
P-value: 0.010
16261588
"""
1627-
params = _check_params_groups_deprecation(fit_params, params, groups, "1.8")
1589+
_check_groups_routing_disabled(groups)
1590+
params = {} if params is None else params
16281591

16291592
X, y, groups = indexable(X, y, groups)
16301593

@@ -1750,7 +1713,6 @@ def _shuffle(y, groups, random_state):
17501713
"random_state": ["random_state"],
17511714
"error_score": [StrOptions({"raise"}), Real],
17521715
"return_times": ["boolean"],
1753-
"fit_params": [dict, None],
17541716
"params": [dict, None],
17551717
},
17561718
prefer_skip_nested_validation=False, # estimator is not validated yet
@@ -1772,7 +1734,6 @@ def learning_curve(
17721734
random_state=None,
17731735
error_score=np.nan,
17741736
return_times=False,
1775-
fit_params=None,
17761737
params=None,
17771738
):
17781739
"""Learning curve.
@@ -1892,13 +1853,6 @@ def learning_curve(
18921853
return_times : bool, default=False
18931854
Whether to return the fit and score times.
18941855
1895-
fit_params : dict, default=None
1896-
Parameters to pass to the fit method of the estimator.
1897-
1898-
.. deprecated:: 1.6
1899-
This parameter is deprecated and will be removed in version 1.8. Use
1900-
``params`` instead.
1901-
19021856
params : dict, default=None
19031857
Parameters to pass to the `fit` method of the estimator and to the scorer.
19041858
@@ -1968,8 +1922,8 @@ def learning_curve(
19681922
"An estimator must support the partial_fit interface "
19691923
"to exploit incremental learning"
19701924
)
1971-
1972-
params = _check_params_groups_deprecation(fit_params, params, groups, "1.8")
1925+
_check_groups_routing_disabled(groups)
1926+
params = {} if params is None else params
19731927

19741928
X, y, groups = indexable(X, y, groups)
19751929

@@ -2254,7 +2208,6 @@ def _incremental_fit_estimator(
22542208
"pre_dispatch": [Integral, str],
22552209
"verbose": ["verbose"],
22562210
"error_score": [StrOptions({"raise"}), Real],
2257-
"fit_params": [dict, None],
22582211
"params": [dict, None],
22592212
},
22602213
prefer_skip_nested_validation=False, # estimator is not validated yet
@@ -2273,7 +2226,6 @@ def validation_curve(
22732226
pre_dispatch="all",
22742227
verbose=0,
22752228
error_score=np.nan,
2276-
fit_params=None,
22772229
params=None,
22782230
):
22792231
"""Validation curve.
@@ -2372,13 +2324,6 @@ def validation_curve(
23722324
23732325
.. versionadded:: 0.20
23742326
2375-
fit_params : dict, default=None
2376-
Parameters to pass to the fit method of the estimator.
2377-
2378-
.. deprecated:: 1.6
2379-
This parameter is deprecated and will be removed in version 1.8. Use
2380-
``params`` instead.
2381-
23822327
params : dict, default=None
23832328
Parameters to pass to the estimator, scorer and cross-validation object.
23842329
@@ -2425,7 +2370,9 @@ def validation_curve(
24252370
>>> print(f"The average test accuracy is {test_scores.mean():.2f}")
24262371
The average test accuracy is 0.81
24272372
"""
2428-
params = _check_params_groups_deprecation(fit_params, params, groups, "1.8")
2373+
_check_groups_routing_disabled(groups)
2374+
params = {} if params is None else params
2375+
24292376
X, y, groups = indexable(X, y, groups)
24302377

24312378
cv = check_cv(cv, y, classifier=is_classifier(estimator))

sklearn/model_selection/tests/test_validation.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2464,35 +2464,6 @@ def test_cross_validate_return_indices(global_random_seed):
24642464
# ======================================================
24652465

24662466

2467-
# TODO(1.8): remove `learning_curve`, `validation_curve` and `permutation_test_score`.
2468-
@pytest.mark.parametrize(
2469-
"func, extra_args",
2470-
[
2471-
(learning_curve, {}),
2472-
(permutation_test_score, {}),
2473-
(validation_curve, {"param_name": "alpha", "param_range": np.array([1])}),
2474-
],
2475-
)
2476-
def test_fit_param_deprecation(func, extra_args):
2477-
"""Check that we warn about deprecating `fit_params`."""
2478-
with pytest.warns(FutureWarning, match="`fit_params` is deprecated"):
2479-
func(
2480-
estimator=ConsumingClassifier(), X=X, y=y, cv=2, fit_params={}, **extra_args
2481-
)
2482-
2483-
with pytest.raises(
2484-
ValueError, match="`params` and `fit_params` cannot both be provided"
2485-
):
2486-
func(
2487-
estimator=ConsumingClassifier(),
2488-
X=X,
2489-
y=y,
2490-
fit_params={},
2491-
params={},
2492-
**extra_args,
2493-
)
2494-
2495-
24962467
@pytest.mark.parametrize(
24972468
"func, extra_args",
24982469
[

0 commit comments

Comments
 (0)