1-
21"""
32The :mod:`sklearn.model_selection._validation` module includes classes and
43functions to validate the model.
@@ -129,6 +128,7 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
129128 X , y , groups = indexable (X , y , groups )
130129
131130 cv = check_cv (cv , y , classifier = is_classifier (estimator ))
131+ cv_iter = list (cv .split (X , y , groups ))
132132 scorer = check_scoring (estimator , scoring = scoring )
133133 # We clone the estimator to make sure that all the folds are
134134 # independent, and that it is pickle-able.
@@ -137,7 +137,7 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
137137 scores = parallel (delayed (_fit_and_score )(clone (estimator ), X , y , scorer ,
138138 train , test , verbose , None ,
139139 fit_params )
140- for train , test in cv . split ( X , y , groups ) )
140+ for train , test in cv_iter )
141141 return np .array (scores )[:, 0 ]
142142
143143
@@ -385,6 +385,7 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
385385 X , y , groups = indexable (X , y , groups )
386386
387387 cv = check_cv (cv , y , classifier = is_classifier (estimator ))
388+ cv_iter = list (cv .split (X , y , groups ))
388389
389390 # Ensure the estimator has implemented the passed decision function
390391 if not callable (getattr (estimator , method )):
@@ -397,7 +398,7 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
397398 pre_dispatch = pre_dispatch )
398399 prediction_blocks = parallel (delayed (_fit_and_predict )(
399400 clone (estimator ), X , y , train , test , verbose , fit_params , method )
400- for train , test in cv . split ( X , y , groups ) )
401+ for train , test in cv_iter )
401402
402403 # Concatenate the predictions
403404 predictions = [pred_block_i for pred_block_i , _ in prediction_blocks ]
@@ -751,9 +752,8 @@ def learning_curve(estimator, X, y, groups=None,
751752 X , y , groups = indexable (X , y , groups )
752753
753754 cv = check_cv (cv , y , classifier = is_classifier (estimator ))
754- cv_iter = cv .split (X , y , groups )
755755 # Make a list since we will be iterating multiple times over the folds
756- cv_iter = list (cv_iter )
756+ cv_iter = list (cv . split ( X , y , groups ) )
757757 scorer = check_scoring (estimator , scoring = scoring )
758758
759759 n_max_training_samples = len (cv_iter [0 ][0 ])
@@ -776,9 +776,8 @@ def learning_curve(estimator, X, y, groups=None,
776776 if exploit_incremental_learning :
777777 classes = np .unique (y ) if is_classifier (estimator ) else None
778778 out = parallel (delayed (_incremental_fit_estimator )(
779- clone (estimator ), X , y , classes , train ,
780- test , train_sizes_abs , scorer , verbose )
781- for train , test in cv_iter )
779+ clone (estimator ), X , y , classes , train , test , train_sizes_abs ,
780+ scorer , verbose ) for train , test in cv_iter )
782781 else :
783782 train_test_proportions = []
784783 for train , test in cv_iter :
@@ -962,6 +961,7 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
962961 X , y , groups = indexable (X , y , groups )
963962
964963 cv = check_cv (cv , y , classifier = is_classifier (estimator ))
964+ cv_iter = list (cv .split (X , y , groups ))
965965
966966 scorer = check_scoring (estimator , scoring = scoring )
967967
@@ -970,7 +970,7 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
970970 out = parallel (delayed (_fit_and_score )(
971971 estimator , X , y , scorer , train , test , verbose ,
972972 parameters = {param_name : v }, fit_params = None , return_train_score = True )
973- for train , test in cv . split ( X , y , groups ) for v in param_range )
973+ for train , test in cv_iter for v in param_range )
974974
975975 out = np .asarray (out )
976976 n_params = len (param_range )
0 commit comments