Commit 38f6a91

raghavrv authored and jnothman committed
[MRG + 2] FIX Be robust to non re-entrant/ non deterministic cv.split calls (scikit-learn#7660)
1 parent 6e50c8f commit 38f6a91

File tree: 7 files changed, +226 −80 lines changed
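Context for the fix: GridSearchCV and the model_selection helpers used to call ``cv.split(X, y, groups)`` inside nested generator expressions, once per parameter setting. That breaks for one-shot iterators (exhausted after the first pass) and for non-deterministic splitters (different folds on every call). A minimal sketch of both failure modes, not part of the commit itself:

import numpy as np
from sklearn.model_selection import KFold

X = np.ones(10)

# Failure mode 1: a generator is non-re-entrant.
gen = KFold(n_splits=2).split(X)
assert len(list(gen)) == 2   # first pass yields both folds
assert list(gen) == []       # second pass yields nothing

# Failure mode 2: a shuffled splitter without a fixed random_state is
# non-deterministic, so folds generally differ between split() calls.
cv = KFold(n_splits=5, shuffle=True)
folds_a = [test for _, test in cv.split(X)]
folds_b = [test for _, test in cv.split(X)]

# The fix pattern used throughout this commit: materialise the splits once
# and reuse the list everywhere.
cv_iter = list(cv.split(X))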

sklearn/model_selection/_search.py

Lines changed: 2 additions & 1 deletion

@@ -550,6 +550,7 @@ def _fit(self, X, y, groups, parameter_iterable):
         base_estimator = clone(self.estimator)
         pre_dispatch = self.pre_dispatch

+        cv_iter = list(cv.split(X, y, groups))
         out = Parallel(
             n_jobs=self.n_jobs, verbose=self.verbose,
             pre_dispatch=pre_dispatch
@@ -561,7 +562,7 @@ def _fit(self, X, y, groups, parameter_iterable):
                                   return_times=True, return_parameters=True,
                                   error_score=self.error_score)
           for parameters in parameter_iterable
-          for train, test in cv.split(X, y, groups))
+          for train, test in cv_iter)

         # if one choose to see train score, "out" will contain train score info
         if self.return_train_score:
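The hoist matters because of how nested generator expressions evaluate: the inner iterable is recomputed for each outer item, so the old code re-invoked ``cv.split`` once per parameter setting. A paraphrased, runnable sketch of that control flow (the tuple collection stands in for the real delayed fit-and-score calls and is not scikit-learn API):

import numpy as np
from sklearn.model_selection import KFold

X, y = np.ones(12), np.arange(12) % 2
cv = KFold(n_splits=3, shuffle=True)   # folds differ between split() calls
param_grid = [{'C': 0.1}, {'C': 1.0}]

# Old behaviour (paraphrased): the inner iterable is re-evaluated for every
# outer item, so cv.split runs once per parameter setting.
old = [(p, test) for p in param_grid for _, test in cv.split(X, y)]

# New behaviour: split exactly once; every parameter setting sees the same
# train/test indices, and a one-shot cv only needs to survive one pass.
cv_iter = list(cv.split(X, y))
new = [(p, test) for p in param_grid for _, test in cv_iter]

The cost is holding all n_splits index arrays in memory at once; in exchange the scores become comparable across parameter settings.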

sklearn/model_selection/_split.py

Lines changed: 1 addition & 1 deletion

@@ -1477,7 +1477,7 @@ def get_n_splits(self, X=None, y=None, groups=None):
 class _CVIterableWrapper(BaseCrossValidator):
     """Wrapper class for old style cv objects and iterables."""
     def __init__(self, cv):
-        self.cv = cv
+        self.cv = list(cv)

     def get_n_splits(self, X=None, y=None, groups=None):
         """Returns the number of splitting iterations in the cross-validator

sklearn/model_selection/_validation.py

Lines changed: 9 additions & 9 deletions

@@ -1,4 +1,3 @@
-
 """
 The :mod:`sklearn.model_selection._validation` module includes classes and
 functions to validate the model.
@@ -129,6 +128,7 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
     X, y, groups = indexable(X, y, groups)

     cv = check_cv(cv, y, classifier=is_classifier(estimator))
+    cv_iter = list(cv.split(X, y, groups))
     scorer = check_scoring(estimator, scoring=scoring)
     # We clone the estimator to make sure that all the folds are
     # independent, and that it is pickle-able.
@@ -137,7 +137,7 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
     scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorer,
                                               train, test, verbose, None,
                                               fit_params)
-                      for train, test in cv.split(X, y, groups))
+                      for train, test in cv_iter)
     return np.array(scores)[:, 0]


@@ -385,6 +385,7 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
     X, y, groups = indexable(X, y, groups)

     cv = check_cv(cv, y, classifier=is_classifier(estimator))
+    cv_iter = list(cv.split(X, y, groups))

     # Ensure the estimator has implemented the passed decision function
     if not callable(getattr(estimator, method)):
@@ -397,7 +398,7 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
                         pre_dispatch=pre_dispatch)
     prediction_blocks = parallel(delayed(_fit_and_predict)(
         clone(estimator), X, y, train, test, verbose, fit_params, method)
-        for train, test in cv.split(X, y, groups))
+        for train, test in cv_iter)

     # Concatenate the predictions
     predictions = [pred_block_i for pred_block_i, _ in prediction_blocks]
@@ -751,9 +752,8 @@ def learning_curve(estimator, X, y, groups=None,
     X, y, groups = indexable(X, y, groups)

     cv = check_cv(cv, y, classifier=is_classifier(estimator))
-    cv_iter = cv.split(X, y, groups)
     # Make a list since we will be iterating multiple times over the folds
-    cv_iter = list(cv_iter)
+    cv_iter = list(cv.split(X, y, groups))
     scorer = check_scoring(estimator, scoring=scoring)

     n_max_training_samples = len(cv_iter[0][0])
@@ -776,9 +776,8 @@ def learning_curve(estimator, X, y, groups=None,
     if exploit_incremental_learning:
         classes = np.unique(y) if is_classifier(estimator) else None
         out = parallel(delayed(_incremental_fit_estimator)(
-            clone(estimator), X, y, classes, train,
-            test, train_sizes_abs, scorer, verbose)
-            for train, test in cv_iter)
+            clone(estimator), X, y, classes, train, test, train_sizes_abs,
+            scorer, verbose) for train, test in cv_iter)
     else:
         train_test_proportions = []
         for train, test in cv_iter:
@@ -962,6 +961,7 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
     X, y, groups = indexable(X, y, groups)

     cv = check_cv(cv, y, classifier=is_classifier(estimator))
+    cv_iter = list(cv.split(X, y, groups))

     scorer = check_scoring(estimator, scoring=scoring)

@@ -970,7 +970,7 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
     out = parallel(delayed(_fit_and_score)(
         estimator, X, y, scorer, train, test, verbose,
         parameters={param_name: v}, fit_params=None, return_train_score=True)
-        for train, test in cv.split(X, y, groups) for v in param_range)
+        for train, test in cv_iter for v in param_range)

     out = np.asarray(out)
     n_params = len(param_range)
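The same one-line hoist is applied in ``cross_val_score``, ``cross_val_predict``, ``learning_curve``, and ``validation_curve``. For ``validation_curve`` in particular, every value in ``param_range`` is now scored on identical folds even when the splitter shuffles, so the resulting curves are directly comparable. A usage sketch under that assumption, using the public API of this scikit-learn version:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import KFold, validation_curve
from sklearn.svm import LinearSVC

X, y = make_classification(n_samples=100, random_state=0)

# With a shuffled splitter, each C value is evaluated on the same five
# train/test partitions, so score differences reflect C alone.
train_scores, test_scores = validation_curve(
    LinearSVC(random_state=0), X, y,
    param_name='C', param_range=[0.1, 1.0, 10.0],
    cv=KFold(n_splits=5, shuffle=True))
print(test_scores.shape)   # (n_params, n_splits) == (3, 5)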
sklearn/model_selection/tests/common.py

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+"""
+Common utilities for testing model selection.
+"""
+
+import numpy as np
+
+from sklearn.model_selection import KFold
+
+
+class OneTimeSplitter:
+    """A wrapper to make KFold single entry cv iterator"""
+    def __init__(self, n_splits=4, n_samples=99):
+        self.n_splits = n_splits
+        self.n_samples = n_samples
+        self.indices = iter(KFold(n_splits=n_splits).split(np.ones(n_samples)))
+
+    def split(self, X=None, y=None, groups=None):
+        """Split can be called only once"""
+        for index in self.indices:
+            yield index
+
+    def get_n_splits(self, X=None, y=None, groups=None):
+        return self.n_splits
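``OneTimeSplitter`` deliberately builds its KFold index iterator once in ``__init__``, so a second pass over ``split`` yields nothing; that is exactly the behaviour the fix must tolerate. A small illustration of its one-shot contract, assuming the helper lands at the path its import in test_search.py indicates:

from sklearn.model_selection.tests.common import OneTimeSplitter

splitter = OneTimeSplitter(n_splits=4, n_samples=8)
assert splitter.get_n_splits() == 4

first_pass = list(splitter.split())    # yields all four (train, test) pairs
assert len(first_pass) == 4
assert list(splitter.split()) == []    # already exhausted: nothing left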

sklearn/model_selection/tests/test_search.py

Lines changed: 57 additions & 0 deletions

@@ -60,6 +60,8 @@
 from sklearn.pipeline import Pipeline
 from sklearn.linear_model import SGDClassifier

+from sklearn.model_selection.tests.common import OneTimeSplitter
+

 # Neither of the following two estimators inherit from BaseEstimator,
 # to test hyperparameter search on user-defined classifiers.
@@ -1154,3 +1156,58 @@ def test_search_train_scores_set_to_false():
     gs = GridSearchCV(clf, param_grid={'C': [0.1, 0.2]},
                       return_train_score=False)
     gs.fit(X, y)
+
+
+def test_grid_search_cv_splits_consistency():
+    # Check if a one time iterable is accepted as a cv parameter.
+    n_samples = 100
+    n_splits = 5
+    X, y = make_classification(n_samples=n_samples, random_state=0)
+
+    gs = GridSearchCV(LinearSVC(random_state=0),
+                      param_grid={'C': [0.1, 0.2, 0.3]},
+                      cv=OneTimeSplitter(n_splits=n_splits,
+                                         n_samples=n_samples))
+    gs.fit(X, y)
+
+    gs2 = GridSearchCV(LinearSVC(random_state=0),
+                       param_grid={'C': [0.1, 0.2, 0.3]},
+                       cv=KFold(n_splits=n_splits))
+    gs2.fit(X, y)
+
+    def _pop_time_keys(cv_results):
+        for key in ('mean_fit_time', 'std_fit_time',
+                    'mean_score_time', 'std_score_time'):
+            cv_results.pop(key)
+        return cv_results
+
+    # OneTimeSplitter is a non-re-entrant cv where split can be called only
+    # once. If ``cv.split`` were called once per param setting in
+    # GridSearchCV.fit, the 2nd and 3rd parameters would not be evaluated,
+    # as no train/test indices would be generated for the 2nd and subsequent
+    # cv.split calls. This is a check to make sure cv.split is not called
+    # once per param setting.
+    np.testing.assert_equal(_pop_time_keys(gs.cv_results_),
+                            _pop_time_keys(gs2.cv_results_))
+
+    # Check consistency of folds across the parameters
+    gs = GridSearchCV(LinearSVC(random_state=0),
+                      param_grid={'C': [0.1, 0.1, 0.2, 0.2]},
+                      cv=KFold(n_splits=n_splits, shuffle=True))
+    gs.fit(X, y)
+
+    # As the first two param settings (C=0.1) and the next two param
+    # settings (C=0.2) are the same, the test and train scores must also be
+    # the same as long as the same train/test indices are generated for all
+    # the cv splits, for both param settings.
+    for score_type in ('train', 'test'):
+        per_param_scores = {}
+        for param_i in range(4):
+            per_param_scores[param_i] = list(
+                gs.cv_results_['split%d_%s_score' % (s, score_type)][param_i]
+                for s in range(5))
+
+        assert_array_almost_equal(per_param_scores[0],
+                                  per_param_scores[1])
+        assert_array_almost_equal(per_param_scores[2],
+                                  per_param_scores[3])
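Note on the test: ``_pop_time_keys`` strips the ``*_fit_time`` and ``*_score_time`` entries before comparing ``cv_results_`` because wall-clock timings differ between otherwise identical runs; everything else, scores included, must match exactly between the one-shot splitter and a plain KFold.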

sklearn/model_selection/tests/test_split.py

Lines changed: 16 additions & 64 deletions

@@ -59,73 +59,9 @@

 X = np.ones(10)
 y = np.arange(10) // 2
-P_sparse = coo_matrix(np.eye(5))
 digits = load_digits()


-class MockClassifier(object):
-    """Dummy classifier to test the cross-validation"""
-
-    def __init__(self, a=0, allow_nd=False):
-        self.a = a
-        self.allow_nd = allow_nd
-
-    def fit(self, X, Y=None, sample_weight=None, class_prior=None,
-            sparse_sample_weight=None, sparse_param=None, dummy_int=None,
-            dummy_str=None, dummy_obj=None, callback=None):
-        """The dummy arguments are to test that this fit function can
-        accept non-array arguments through cross-validation, such as:
-            - int
-            - str (this is actually array-like)
-            - object
-            - function
-        """
-        self.dummy_int = dummy_int
-        self.dummy_str = dummy_str
-        self.dummy_obj = dummy_obj
-        if callback is not None:
-            callback(self)
-
-        if self.allow_nd:
-            X = X.reshape(len(X), -1)
-        if X.ndim >= 3 and not self.allow_nd:
-            raise ValueError('X cannot be d')
-        if sample_weight is not None:
-            assert_true(sample_weight.shape[0] == X.shape[0],
-                        'MockClassifier extra fit_param sample_weight.shape[0]'
-                        ' is {0}, should be {1}'.format(sample_weight.shape[0],
-                                                        X.shape[0]))
-        if class_prior is not None:
-            assert_true(class_prior.shape[0] == len(np.unique(y)),
-                        'MockClassifier extra fit_param class_prior.shape[0]'
-                        ' is {0}, should be {1}'.format(class_prior.shape[0],
-                                                        len(np.unique(y))))
-        if sparse_sample_weight is not None:
-            fmt = ('MockClassifier extra fit_param sparse_sample_weight'
-                   '.shape[0] is {0}, should be {1}')
-            assert_true(sparse_sample_weight.shape[0] == X.shape[0],
-                        fmt.format(sparse_sample_weight.shape[0], X.shape[0]))
-        if sparse_param is not None:
-            fmt = ('MockClassifier extra fit_param sparse_param.shape '
-                   'is ({0}, {1}), should be ({2}, {3})')
-            assert_true(sparse_param.shape == P_sparse.shape,
-                        fmt.format(sparse_param.shape[0],
-                                   sparse_param.shape[1],
-                                   P_sparse.shape[0], P_sparse.shape[1]))
-        return self
-
-    def predict(self, T):
-        if self.allow_nd:
-            T = T.reshape(len(T), -1)
-        return T[:, 0]
-
-    def score(self, X=None, Y=None):
-        return 1. / (1 + np.abs(self.a))
-
-    def get_params(self, deep=False):
-        return {'a': self.a, 'allow_nd': self.allow_nd}
-
-
 @ignore_warnings
 def test_cross_validator_with_default_params():
     n_samples = 4
@@ -933,6 +869,22 @@ def test_cv_iterable_wrapper():
     # Check if get_n_splits works correctly
     assert_equal(len(cv), wrapped_old_skf.get_n_splits())

+    kf_iter = KFold(n_splits=5).split(X, y)
+    kf_iter_wrapped = check_cv(kf_iter)
+    # Since the wrapped iterable is enlisted and stored,
+    # split can be called any number of times to produce
+    # consistent results.
+    assert_array_equal(list(kf_iter_wrapped.split(X, y)),
+                       list(kf_iter_wrapped.split(X, y)))
+    # If the splits are randomized, successive calls to split yield different
+    # results.
+    kf_randomized_iter = KFold(n_splits=5, shuffle=True).split(X, y)
+    kf_randomized_iter_wrapped = check_cv(kf_randomized_iter)
+    assert_array_equal(list(kf_randomized_iter_wrapped.split(X, y)),
+                       list(kf_randomized_iter_wrapped.split(X, y)))
+    assert_true(np.any(np.array(list(kf_iter_wrapped.split(X, y))) !=
+                       np.array(list(kf_randomized_iter_wrapped.split(X, y)))))
+

 def test_group_kfold():
     rng = np.random.RandomState(0)
