Skip to content

Commit b4d7c96

Browse files
thomasjpfanjnothman
authored andcommitted
[MRG] Fixes out of bound check in *SearchCV for callable refit (scikit-learn#13417)
1 parent 7a636f0 commit b4d7c96

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

sklearn/model_selection/_search.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,8 @@ def evaluate_candidates(candidate_params):
698698
self.best_index_ = self.refit(results)
699699
if not isinstance(self.best_index_, (int, np.integer)):
700700
raise TypeError('best_index_ returned is not an integer')
701-
if self.best_index_ < 0 or self.best_index_ >= len(results):
701+
if (self.best_index_ < 0 or
702+
self.best_index_ >= len(results["params"])):
702703
raise IndexError('best_index_ index out of range')
703704
else:
704705
self.best_index_ = results["rank_test_%s"

sklearn/model_selection/tests/test_search.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,9 @@ def refit_callable_invalid_type(cv_results):
651651
clf.fit(X, y)
652652

653653

654-
def test_refit_callable_out_bound():
654+
@pytest.mark.parametrize('out_bound_value', [-1, 2])
655+
@pytest.mark.parametrize('search_cv', [RandomizedSearchCV, GridSearchCV])
656+
def test_refit_callable_out_bound(out_bound_value, search_cv):
655657
"""
656658
Test implementation catches the errors when 'best_index_' returns an
657659
out of bound result.
@@ -660,14 +662,13 @@ def refit_callable_out_bound(cv_results):
660662
"""
661663
A dummy function tests when returned 'best_index_' is out of bounds.
662664
"""
663-
return -1
665+
return out_bound_value
664666

665667
X, y = make_classification(n_samples=100, n_features=4,
666668
random_state=42)
667669

668-
clf = GridSearchCV(LinearSVC(random_state=42), {'C': [0.1, 1]},
669-
scoring='precision', refit=refit_callable_out_bound,
670-
cv=5)
670+
clf = search_cv(LinearSVC(random_state=42), {'C': [0.1, 1]},
671+
scoring='precision', refit=refit_callable_out_bound, cv=5)
671672
with pytest.raises(IndexError, match='best_index_ index out of range'):
672673
clf.fit(X, y)
673674

0 commit comments

Comments
 (0)