Skip to content

Commit 06c8996

Browse files
committed
Support of the collections.Sequence type has been added to the _check_param_grid method from model_selection.
1 parent 6297815 commit 06c8996

File tree

2 files changed

+29
-20
lines changed

2 files changed

+29
-20
lines changed

sklearn/model_selection/_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# License: BSD 3 clause
1313

1414
from abc import ABCMeta, abstractmethod
15-
from collections import Mapping, namedtuple, Sized, defaultdict
15+
from collections import Mapping, namedtuple, Sized, defaultdict, Sequence
1616
from functools import partial, reduce
1717
from itertools import product
1818
import operator
@@ -332,7 +332,7 @@ def _check_param_grid(param_grid):
332332
if isinstance(v, np.ndarray) and v.ndim > 1:
333333
raise ValueError("Parameter array should be one-dimensional.")
334334

335-
check = [isinstance(v, k) for k in (list, tuple, np.ndarray)]
335+
check = [isinstance(v, k) for k in (np.ndarray, Sequence)]
336336
if True not in check:
337337
raise ValueError("Parameter values for parameter ({0}) need "
338338
"to be a sequence.".format(name))

sklearn/model_selection/tests/test_search.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections import Iterable, Sized
44
from sklearn.externals.six.moves import cStringIO as StringIO
55
from sklearn.externals.six.moves import xrange
6+
from sklearn.externals.joblib._compat import PY3_OR_LATER
67
from itertools import chain, product
78
import pickle
89
import sys
@@ -168,22 +169,6 @@ def test_grid_search():
168169
assert_raises(ValueError, grid_search.fit, X, y)
169170

170171

171-
def test_grid_search_incorrect_param_grid():
172-
clf = MockClassifier()
173-
assert_raise_message(
174-
ValueError,
175-
"Parameter values for parameter (C) need to be a sequence.",
176-
GridSearchCV, clf, {'C': 1})
177-
178-
179-
def test_grid_search_param_grid_includes_sequence_of_a_zero_length():
180-
clf = MockClassifier()
181-
assert_raise_message(
182-
ValueError,
183-
"Parameter values for parameter (C) need to be a non-empty sequence.",
184-
GridSearchCV, clf, {'C': []})
185-
186-
187172
@ignore_warnings
188173
def test_grid_search_no_score():
189174
# Test grid-search on classifier that has no score function.
@@ -319,14 +304,38 @@ def test_grid_search_one_grid_point():
319304
assert_array_equal(clf.dual_coef_, cv.best_estimator_.dual_coef_)
320305

321306

307+
def test_grid_search_when_param_grid_includes_range():
308+
# Test that the best estimator contains the right value for foo_param
309+
clf = MockClassifier()
310+
grid_search = None
311+
if PY3_OR_LATER:
312+
grid_search = GridSearchCV(clf, {'foo_param': range(1, 4)}, verbose=3)
313+
else:
314+
grid_search = GridSearchCV(clf, {'foo_param': xrange(1, 4)}, verbose=3)
315+
# make sure it selects the smallest parameter in case of ties
316+
old_stdout = sys.stdout
317+
sys.stdout = StringIO()
318+
grid_search.fit(X, y)
319+
sys.stdout = old_stdout
320+
assert_equal(grid_search.best_estimator_.foo_param, 2)
321+
322+
assert_array_equal(grid_search.results_["param_foo_param"].data, [1, 2, 3])
323+
324+
322325
def test_grid_search_bad_param_grid():
323326
param_dict = {"C": 1.0}
324327
clf = SVC()
325-
assert_raises(ValueError, GridSearchCV, clf, param_dict)
328+
assert_raise_message(
329+
ValueError,
330+
"Parameter values for parameter (C) need to be a sequence.",
331+
GridSearchCV, clf, param_dict)
326332

327333
param_dict = {"C": []}
328334
clf = SVC()
329-
assert_raises(ValueError, GridSearchCV, clf, param_dict)
335+
assert_raise_message(
336+
ValueError,
337+
"Parameter values for parameter (C) need to be a non-empty sequence.",
338+
GridSearchCV, clf, param_dict)
330339

331340
param_dict = {"C": np.ones(6).reshape(3, 2)}
332341
clf = SVC()

0 commit comments

Comments
 (0)