|
3 | 3 | from collections import Iterable, Sized |
4 | 4 | from sklearn.externals.six.moves import cStringIO as StringIO |
5 | 5 | from sklearn.externals.six.moves import xrange |
| 6 | +from sklearn.externals.joblib._compat import PY3_OR_LATER |
6 | 7 | from itertools import chain, product |
7 | 8 | import pickle |
8 | 9 | import sys |
@@ -168,22 +169,6 @@ def test_grid_search(): |
168 | 169 | assert_raises(ValueError, grid_search.fit, X, y) |
169 | 170 |
|
170 | 171 |
|
171 | | -def test_grid_search_incorrect_param_grid(): |
172 | | - clf = MockClassifier() |
173 | | - assert_raise_message( |
174 | | - ValueError, |
175 | | - "Parameter values for parameter (C) need to be a sequence.", |
176 | | - GridSearchCV, clf, {'C': 1}) |
177 | | - |
178 | | - |
179 | | -def test_grid_search_param_grid_includes_sequence_of_a_zero_length(): |
180 | | - clf = MockClassifier() |
181 | | - assert_raise_message( |
182 | | - ValueError, |
183 | | - "Parameter values for parameter (C) need to be a non-empty sequence.", |
184 | | - GridSearchCV, clf, {'C': []}) |
185 | | - |
186 | | - |
187 | 172 | @ignore_warnings |
188 | 173 | def test_grid_search_no_score(): |
189 | 174 | # Test grid-search on classifier that has no score function. |
@@ -319,14 +304,38 @@ def test_grid_search_one_grid_point(): |
319 | 304 | assert_array_equal(clf.dual_coef_, cv.best_estimator_.dual_coef_) |
320 | 305 |
|
321 | 306 |
|
| 307 | +def test_grid_search_when_param_grid_includes_range(): |
| 308 | + # Test that the best estimator contains the right value for foo_param |
| 309 | + clf = MockClassifier() |
| 310 | + grid_search = None |
| 311 | + if PY3_OR_LATER: |
| 312 | + grid_search = GridSearchCV(clf, {'foo_param': range(1, 4)}, verbose=3) |
| 313 | + else: |
| 314 | + grid_search = GridSearchCV(clf, {'foo_param': xrange(1, 4)}, verbose=3) |
| 315 | + # make sure it selects the smallest parameter in case of ties |
| 316 | + old_stdout = sys.stdout |
| 317 | + sys.stdout = StringIO() |
| 318 | + grid_search.fit(X, y) |
| 319 | + sys.stdout = old_stdout |
| 320 | + assert_equal(grid_search.best_estimator_.foo_param, 2) |
| 321 | + |
| 322 | + assert_array_equal(grid_search.results_["param_foo_param"].data, [1, 2, 3]) |
| 323 | + |
| 324 | + |
322 | 325 | def test_grid_search_bad_param_grid(): |
323 | 326 | param_dict = {"C": 1.0} |
324 | 327 | clf = SVC() |
325 | | - assert_raises(ValueError, GridSearchCV, clf, param_dict) |
| 328 | + assert_raise_message( |
| 329 | + ValueError, |
| 330 | + "Parameter values for parameter (C) need to be a sequence.", |
| 331 | + GridSearchCV, clf, param_dict) |
326 | 332 |
|
327 | 333 | param_dict = {"C": []} |
328 | 334 | clf = SVC() |
329 | | - assert_raises(ValueError, GridSearchCV, clf, param_dict) |
| 335 | + assert_raise_message( |
| 336 | + ValueError, |
| 337 | + "Parameter values for parameter (C) need to be a non-empty sequence.", |
| 338 | + GridSearchCV, clf, param_dict) |
330 | 339 |
|
331 | 340 | param_dict = {"C": np.ones(6).reshape(3, 2)} |
332 | 341 | clf = SVC() |
|
0 commit comments