
Commit 1e386a4

amy12xx and ogrisel authored
ENH Add support for 'fit_params' to learning_curve (scikit-learn#18595)
Co-authored-by: Olivier Grisel <[email protected]>
1 parent 0881f2c commit 1e386a4

3 files changed: 80 additions and 8 deletions

doc/whats_new/v0.24.rst  (4 additions, 0 deletions)

@@ -474,6 +474,10 @@ Changelog
   :pr:`18266` by :user:`Subrat Sahu <subrat93>`,
   :user:`Nirvan <Nirvan101>` and :user:`Arthur Book <ArthurBook>`.
 
+- |Enhancement| :func:`model_selection.learning_curve` now accepts fit_params
+  to pass additional estimator parameters.
+  :pr:`18595` by :user:`Amanda Dsouza <amy12xx>`.
+
 :mod:`sklearn.multiclass`
 .........................

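For context, a minimal usage sketch of the new keyword (illustrative only, not part of the commit). It assumes an estimator whose fit accepts sample_weight, and that the non-incremental path subsets sample-aligned fit parameters per training split, as the full-length sample_weight accepted by the new tests suggests:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import learning_curve

X, y = make_classification(n_samples=100, random_state=0)
sample_weight = np.ones(X.shape[0])  # one weight per sample, full length

# fit_params is forwarded to the estimator's fit() for every training subset.
train_sizes, train_scores, test_scores = learning_curve(
    LogisticRegression(), X, y, cv=5,
    fit_params={'sample_weight': sample_weight})
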
sklearn/model_selection/_validation.py  (16 additions, 7 deletions)

@@ -1214,8 +1214,8 @@ def learning_curve(estimator, X, y, *, groups=None,
                    train_sizes=np.linspace(0.1, 1.0, 5), cv=None,
                    scoring=None, exploit_incremental_learning=False,
                    n_jobs=None, pre_dispatch="all", verbose=0, shuffle=False,
-                   random_state=None, error_score=np.nan,
-                   return_times=False):
+                   random_state=None, error_score=np.nan, return_times=False,
+                   fit_params=None):
     """Learning curve.
 
     Determines cross-validated training and test scores for different training
@@ -1319,6 +1319,11 @@ def learning_curve(estimator, X, y, *, groups=None,
     return_times : bool, default=False
         Whether to return the fit and score times.
 
+    fit_params : dict, default=None
+        Parameters to pass to the fit method of the estimator.
+
+        .. versionadded:: 0.24
+
     Returns
     -------
     train_sizes_abs : array of shape (n_unique_ticks,)
@@ -1377,7 +1382,8 @@ def learning_curve(estimator, X, y, *, groups=None,
         classes = np.unique(y) if is_classifier(estimator) else None
         out = parallel(delayed(_incremental_fit_estimator)(
             clone(estimator), X, y, classes, train, test, train_sizes_abs,
-            scorer, verbose, return_times, error_score=error_score)
+            scorer, verbose, return_times, error_score=error_score,
+            fit_params=fit_params)
             for train, test in cv_iter
         )
         out = np.asarray(out).transpose((2, 1, 0))
@@ -1389,7 +1395,7 @@ def learning_curve(estimator, X, y, *, groups=None,
 
         results = parallel(delayed(_fit_and_score)(
             clone(estimator), X, y, scorer, train, test, verbose,
-            parameters=None, fit_params=None, return_train_score=True,
+            parameters=None, fit_params=fit_params, return_train_score=True,
             error_score=error_score, return_times=return_times)
             for train, test in train_test_proportions
         )
@@ -1472,10 +1478,12 @@ def _translate_train_sizes(train_sizes, n_max_training_samples):
 
 def _incremental_fit_estimator(estimator, X, y, classes, train, test,
                                train_sizes, scorer, verbose,
-                               return_times, error_score):
+                               return_times, error_score, fit_params):
     """Train estimator on training subsets incrementally and compute scores."""
     train_scores, test_scores, fit_times, score_times = [], [], [], []
     partitions = zip(train_sizes, np.split(train, train_sizes)[:-1])
+    if fit_params is None:
+        fit_params = {}
     for n_train_samples, partial_train in partitions:
         train_subset = train[:n_train_samples]
         X_train, y_train = _safe_split(estimator, X, y, train_subset)
@@ -1484,10 +1492,11 @@ def _incremental_fit_estimator(estimator, X, y, classes, train, test,
         X_test, y_test = _safe_split(estimator, X, y, test, train_subset)
         start_fit = time.time()
         if y_partial_train is None:
-            estimator.partial_fit(X_partial_train, classes=classes)
+            estimator.partial_fit(X_partial_train, classes=classes,
+                                  **fit_params)
         else:
             estimator.partial_fit(X_partial_train, y_partial_train,
-                                  classes=classes)
+                                  classes=classes, **fit_params)
         fit_time = time.time() - start_fit
         fit_times.append(fit_time)

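In the incremental path above, fit_params is unpacked directly into each partial_fit call rather than re-indexed per subset, so sample-aligned values appear to need the per-batch length (the new tests pass a length-2 sample_weight for batches of two samples). A sketch under that assumption, with SGDClassifier chosen here purely for illustration:

import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import learning_curve

X = np.arange(60, dtype=float).reshape(30, 2)
y = np.array([0, 1] * 15)

# With cv=3 and ten train-size ticks, every incremental batch holds 2 samples,
# so a length-2 sample_weight is forwarded verbatim to each partial_fit call.
learning_curve(SGDClassifier(random_state=0), X, y, cv=3,
               exploit_incremental_learning=True,
               train_sizes=np.linspace(0.1, 1.0, 10),
               fit_params={'sample_weight': np.ones(2)})
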
sklearn/model_selection/tests/test_validation.py  (60 additions, 1 deletion)

@@ -25,6 +25,8 @@
 from sklearn.utils._testing import assert_allclose
 from sklearn.utils._mocking import CheckingClassifier, MockDataFrame
 
+from sklearn.utils.validation import _num_samples
+
 from sklearn.model_selection import cross_val_score, ShuffleSplit
 from sklearn.model_selection import cross_val_predict
 from sklearn.model_selection import cross_validate
@@ -114,9 +116,10 @@ def _is_training_data(self, X):
 
 class MockIncrementalImprovingEstimator(MockImprovingEstimator):
     """Dummy classifier that provides partial_fit"""
-    def __init__(self, n_max_train_sizes):
+    def __init__(self, n_max_train_sizes, expected_fit_params=None):
         super().__init__(n_max_train_sizes)
         self.x = None
+        self.expected_fit_params = expected_fit_params
 
     def _is_training_data(self, X):
         return self.x in X
@@ -125,6 +128,20 @@ def partial_fit(self, X, y=None, **params):
         self.train_sizes += X.shape[0]
         self.x = X[0]
 
+        if self.expected_fit_params:
+            missing = set(self.expected_fit_params) - set(params)
+            if missing:
+                raise AssertionError(
+                    f'Expected fit parameter(s) {list(missing)} not seen.'
+                )
+            for key, value in params.items():
+                if key in self.expected_fit_params and \
+                        _num_samples(value) != _num_samples(X):
+                    raise AssertionError(
+                        f'Fit parameter {key} has length {_num_samples(value)}'
+                        f'; expected {_num_samples(X)}.'
+                    )
+
 
 class MockEstimatorWithParameter(BaseEstimator):
     """Dummy classifier to test the validation curve"""
@@ -1249,6 +1266,48 @@ def test_learning_curve_with_shuffle():
                         test_scores_batch.mean(axis=1))
 
 
+def test_learning_curve_fit_params():
+    X = np.arange(100).reshape(10, 10)
+    y = np.array([0] * 5 + [1] * 5)
+    clf = CheckingClassifier(expected_fit_params=['sample_weight'])
+
+    err_msg = r"Expected fit parameter\(s\) \['sample_weight'\] not seen."
+    with pytest.raises(AssertionError, match=err_msg):
+        learning_curve(clf, X, y, error_score='raise')
+
+    err_msg = "Fit parameter sample_weight has length 1; expected"
+    with pytest.raises(AssertionError, match=err_msg):
+        learning_curve(clf, X, y, error_score='raise',
+                       fit_params={'sample_weight': np.ones(1)})
+    learning_curve(clf, X, y, error_score='raise',
+                   fit_params={'sample_weight': np.ones(10)})
+
+
+def test_learning_curve_incremental_learning_fit_params():
+    X, y = make_classification(n_samples=30, n_features=1, n_informative=1,
+                               n_redundant=0, n_classes=2,
+                               n_clusters_per_class=1, random_state=0)
+    estimator = MockIncrementalImprovingEstimator(20, ['sample_weight'])
+    err_msg = r"Expected fit parameter\(s\) \['sample_weight'\] not seen."
+    with pytest.raises(AssertionError, match=err_msg):
+        learning_curve(estimator, X, y, cv=3,
+                       exploit_incremental_learning=True,
+                       train_sizes=np.linspace(0.1, 1.0, 10),
+                       error_score='raise')
+
+    err_msg = "Fit parameter sample_weight has length 3; expected"
+    with pytest.raises(AssertionError, match=err_msg):
+        learning_curve(estimator, X, y, cv=3,
+                       exploit_incremental_learning=True,
+                       train_sizes=np.linspace(0.1, 1.0, 10),
+                       error_score='raise',
+                       fit_params={'sample_weight': np.ones(3)})
+
+    learning_curve(estimator, X, y, cv=3, exploit_incremental_learning=True,
+                   train_sizes=np.linspace(0.1, 1.0, 10), error_score='raise',
+                   fit_params={'sample_weight': np.ones(2)})
+
+
 def test_validation_curve():
     X, y = make_classification(n_samples=2, n_features=1, n_informative=1,
                                n_redundant=0, n_classes=2,