Skip to content

Commit 0a5af0d

Browse files
NicolasHugjnothmanamueller
authored
FEA Successive halving for faster parameter search (scikit-learn#13900)
* More flexible grid search interface * added info dict parameter * Put back removed test * renamed info into more_results * Passed grroups as well since we need n_to use get_n_splits(X, y, groups) * port * pep8 * dabl -> sklearn * add _required_parameters * skipping check in rst file if pandas not installed * Update sklearn/model_selection/_search_successive_halving.py Co-Authored-By: Joel Nothman <[email protected]> * renamed into GridHalvingSearchCV and RandomHalvingSearchCV * Addressed thomas' comments * repr * removed passing group as a parameter to evaluate_candidates * Joels comments * pep8 * reorganized user user guide * renaming * update user guide * remove groups support + pass fit_params * parameter renaming * pep8 * r_i -> resource_iter * fixed r_i issues * examples + removed use of word budget * Added inpute checking tests * added cv_resutlts_ user guide * minor title change * fixed doc layout * Addressed some comments * properly pass down fit_params * change default value of force_exhaust_resources and update doc * should fix doc * Used check_fit_params * Update section about min_resources and number of candidates * Clarified ratio section * Use ~ to refer to classes * fixed doc checks * Apply suggestions from code review Co-authored-by: Joel Nothman <[email protected]> * Addressed easy comments from Joel * missed some * updated docstring of run_search * Used f strings instead of format * remove candidate duplication checks * fix example * Addressed easy comments * rotate ticks labels * Added discussion in the intro as suggested by Joel * Split examples into sections * minor changes * remove force_exhaust_budget and introduce min_resources=exhaust * some minor validation * Added a n_resources_ attribute * update examples * Addressed comments * passing CV instead of X,y * minor revert for handling fit_params * updated docs * fix len * whatsnew * Add test for sampling when all_list * minor change to top-k * Force CV splits to be consistent across calls * reorder parameters * reduced diff * added tests for top_k * put back doc for groups * not sure what went wrong * put import at its place * some comment * Addressed comments * Added tests for cv_results_ and base estimator inputs * pep8 * avoid monkeypatching * rename df * use Joel's suggestions for testing masks * Made it experimental * Should fix docs * whats new entry * Apply suggestions from code review Co-authored-by: Andreas Mueller <[email protected]> * Addressed comments to docs * Addressed comments in examples * minor doc update * minor renaming in UG * forgot some * some sad note about splitter statefulness :'( * Addressed comments * ratio -> factor Co-authored-by: Joel Nothman <[email protected]> Co-authored-by: Andreas Mueller <[email protected]>
1 parent 5b29166 commit 0a5af0d

File tree

16 files changed

+2276
-29
lines changed

16 files changed

+2276
-29
lines changed

doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ def __call__(self, directory):
356356
# discovered properly by sphinx
357357
from sklearn.experimental import enable_hist_gradient_boosting # noqa
358358
from sklearn.experimental import enable_iterative_imputer # noqa
359+
from sklearn.experimental import enable_successive_halving # noqa
359360

360361

361362
def make_carousel_thumbs(app, exception):

doc/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ def setup_impute():
5757
raise SkipTest("Skipping impute.rst, pandas not installed")
5858

5959

60+
def setup_grid_search():
61+
try:
62+
import pandas # noqa
63+
except ImportError:
64+
raise SkipTest("Skipping grid_search.rst, pandas not installed")
65+
66+
6067
def setup_unsupervised_learning():
6168
try:
6269
import skimage # noqa
@@ -86,5 +93,7 @@ def pytest_runtest_setup(item):
8693
raise SkipTest('FeatureHasher is not compatible with PyPy')
8794
elif fname.endswith('modules/impute.rst'):
8895
setup_impute()
96+
elif fname.endswith('modules/grid_search.rst'):
97+
setup_grid_search()
8998
elif fname.endswith('statistical_inference/unsupervised_learning.rst'):
9099
setup_unsupervised_learning()

doc/modules/classes.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,9 +1194,11 @@ Hyper-parameter optimizers
11941194
:template: class.rst
11951195

11961196
model_selection.GridSearchCV
1197+
model_selection.HalvingGridSearchCV
11971198
model_selection.ParameterGrid
11981199
model_selection.ParameterSampler
11991200
model_selection.RandomizedSearchCV
1201+
model_selection.HalvingRandomSearchCV
12001202

12011203

12021204
Model validation

doc/modules/grid_search.rst

Lines changed: 390 additions & 16 deletions
Large diffs are not rendered by default.

doc/whats_new/v0.24.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,14 @@ Changelog
412412
:pr:`17478` by :user:`Teon Brooks <teonbrooks>` and
413413
:user:`Mohamed Maskani <maskani-moh>`.
414414

415+
- |Feature| Added (experimental) parameter search estimators
416+
:class:`model_selection.HalvingRandomSearchCV` and
417+
:class:`model_selection.HalvingGridSearchCV` which implement Successive
418+
Halving, and can be used as a drop-in replacements for
419+
:class:`model_selection.RandomizedSearchCV` and
420+
:class:`model_selection.GridSearchCV`. :pr:`13900` by `Nicolas Hug`_, `Joel
421+
Nothman`_ and `Andreas Müller`_.
422+
415423
- |Fix| Fixed the `len` of :class:`model_selection.ParameterSampler` when
416424
all distributions are lists and `n_iter` is more than the number of unique
417425
parameter combinations. :pr:`18222` by `Nicolas Hug`_.
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""
2+
Comparison between grid search and successive halving
3+
=====================================================
4+
5+
This example compares the parameter search performed by
6+
:class:`~sklearn.model_selection.HalvingGridSearchCV` and
7+
:class:`~sklearn.model_selection.GridSearchCV`.
8+
9+
"""
10+
from time import time
11+
12+
import matplotlib.pyplot as plt
13+
import numpy as np
14+
import pandas as pd
15+
16+
from sklearn.svm import SVC
17+
from sklearn import datasets
18+
from sklearn.model_selection import GridSearchCV
19+
from sklearn.experimental import enable_successive_halving # noqa
20+
from sklearn.model_selection import HalvingGridSearchCV
21+
22+
23+
print(__doc__)
24+
25+
# %%
26+
# We first define the parameter space for an :class:`~sklearn.svm.SVC`
27+
# estimator, and compute the time required to train a
28+
# :class:`~sklearn.model_selection.HalvingGridSearchCV` instance, as well as a
29+
# :class:`~sklearn.model_selection.GridSearchCV` instance.
30+
31+
rng = np.random.RandomState(0)
32+
X, y = datasets.make_classification(n_samples=1000, random_state=rng)
33+
34+
gammas = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
35+
Cs = [1, 10, 100, 1e3, 1e4, 1e5]
36+
param_grid = {'gamma': gammas, 'C': Cs}
37+
38+
clf = SVC(random_state=rng)
39+
40+
tic = time()
41+
gsh = HalvingGridSearchCV(estimator=clf, param_grid=param_grid, factor=2,
42+
random_state=rng)
43+
gsh.fit(X, y)
44+
gsh_time = time() - tic
45+
46+
tic = time()
47+
gs = GridSearchCV(estimator=clf, param_grid=param_grid)
48+
gs.fit(X, y)
49+
gs_time = time() - tic
50+
51+
# %%
52+
# We now plot heatmaps for both search estimators.
53+
54+
55+
def make_heatmap(ax, gs, is_sh=False, make_cbar=False):
56+
"""Helper to make a heatmap."""
57+
results = pd.DataFrame.from_dict(gs.cv_results_)
58+
results['params_str'] = results.params.apply(str)
59+
if is_sh:
60+
# SH dataframe: get mean_test_score values for the highest iter
61+
scores_matrix = results.sort_values('iter').pivot_table(
62+
index='param_gamma', columns='param_C',
63+
values='mean_test_score', aggfunc='last'
64+
)
65+
else:
66+
scores_matrix = results.pivot(index='param_gamma', columns='param_C',
67+
values='mean_test_score')
68+
69+
im = ax.imshow(scores_matrix)
70+
71+
ax.set_xticks(np.arange(len(Cs)))
72+
ax.set_xticklabels(['{:.0E}'.format(x) for x in Cs])
73+
ax.set_xlabel('C', fontsize=15)
74+
75+
ax.set_yticks(np.arange(len(gammas)))
76+
ax.set_yticklabels(['{:.0E}'.format(x) for x in gammas])
77+
ax.set_ylabel('gamma', fontsize=15)
78+
79+
# Rotate the tick labels and set their alignment.
80+
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
81+
rotation_mode="anchor")
82+
83+
if is_sh:
84+
iterations = results.pivot_table(index='param_gamma',
85+
columns='param_C', values='iter',
86+
aggfunc='max').values
87+
for i in range(len(gammas)):
88+
for j in range(len(Cs)):
89+
ax.text(j, i, iterations[i, j],
90+
ha="center", va="center", color="w", fontsize=20)
91+
92+
if make_cbar:
93+
fig.subplots_adjust(right=0.8)
94+
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
95+
fig.colorbar(im, cax=cbar_ax)
96+
cbar_ax.set_ylabel('mean_test_score', rotation=-90, va="bottom",
97+
fontsize=15)
98+
99+
100+
fig, axes = plt.subplots(ncols=2, sharey=True)
101+
ax1, ax2 = axes
102+
103+
make_heatmap(ax1, gsh, is_sh=True)
104+
make_heatmap(ax2, gs, make_cbar=True)
105+
106+
ax1.set_title('Successive Halving\ntime = {:.3f}s'.format(gsh_time),
107+
fontsize=15)
108+
ax2.set_title('GridSearch\ntime = {:.3f}s'.format(gs_time), fontsize=15)
109+
110+
plt.show()
111+
112+
# %%
113+
# The heatmaps show the mean test score of the parameter combinations for an
114+
# :class:`~sklearn.svm.SVC` instance. The
115+
# :class:`~sklearn.model_selection.HalvingGridSearchCV` also shows the
116+
# iteration at which the combinations where last used. The combinations marked
117+
# as ``0`` were only evaluated at the first iteration, while the ones with
118+
# ``5`` are the parameter combinations that are considered the best ones.
119+
#
120+
# We can see that the :class:`~sklearn.model_selection.HalvingGridSearchCV`
121+
# class is able to find parameter combinations that are just as accurate as
122+
# :class:`~sklearn.model_selection.GridSearchCV`, in much less time.
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""
2+
Successive Halving Iterations
3+
=============================
4+
5+
This example illustrates how a successive halving search (
6+
:class:`~sklearn.model_selection.HalvingGridSearchCV` and
7+
:class:`~sklearn.model_selection.HalvingRandomSearchCV`) iteratively chooses
8+
the best parameter combination out of multiple candidates.
9+
10+
"""
11+
import pandas as pd
12+
from sklearn import datasets
13+
import matplotlib.pyplot as plt
14+
from scipy.stats import randint
15+
import numpy as np
16+
17+
from sklearn.experimental import enable_successive_halving # noqa
18+
from sklearn.model_selection import HalvingRandomSearchCV
19+
from sklearn.ensemble import RandomForestClassifier
20+
21+
22+
print(__doc__)
23+
24+
# %%
25+
# We first define the parameter space and train a
26+
# :class:`~sklearn.model_selection.HalvingRandomSearchCV` instance.
27+
28+
rng = np.random.RandomState(0)
29+
30+
X, y = datasets.make_classification(n_samples=700, random_state=rng)
31+
32+
clf = RandomForestClassifier(n_estimators=20, random_state=rng)
33+
34+
param_dist = {"max_depth": [3, None],
35+
"max_features": randint(1, 11),
36+
"min_samples_split": randint(2, 11),
37+
"bootstrap": [True, False],
38+
"criterion": ["gini", "entropy"]}
39+
40+
rsh = HalvingRandomSearchCV(
41+
estimator=clf,
42+
param_distributions=param_dist,
43+
factor=2,
44+
random_state=rng)
45+
rsh.fit(X, y)
46+
47+
# %%
48+
# We can now use the `cv_results_` attribute of the search estimator to inspect
49+
# and plot the evolution of the search.
50+
51+
results = pd.DataFrame(rsh.cv_results_)
52+
results['params_str'] = results.params.apply(str)
53+
results.drop_duplicates(subset=('params_str', 'iter'), inplace=True)
54+
mean_scores = results.pivot(index='iter', columns='params_str',
55+
values='mean_test_score')
56+
ax = mean_scores.plot(legend=False, alpha=.6)
57+
58+
labels = [
59+
f'iter={i}\nn_samples={rsh.n_resources_[i]}\n'
60+
f'n_candidates={rsh.n_candidates_[i]}'
61+
for i in range(rsh.n_iterations_)
62+
]
63+
ax.set_xticklabels(labels, rotation=45, multialignment='left')
64+
ax.set_title('Scores of candidates over iterations')
65+
ax.set_ylabel('mean test score', fontsize=15)
66+
ax.set_xlabel('iterations', fontsize=15)
67+
plt.tight_layout()
68+
plt.show()
69+
70+
# %%
71+
# Number of candidates and amount of resource at each iteration
72+
# -------------------------------------------------------------
73+
#
74+
# At the first iteration, a small amount of resources is used. The resource
75+
# here is the number of samples that the estimators are trained on. All
76+
# candidates are evaluated.
77+
#
78+
# At the second iteration, only the best half of the candidates is evaluated.
79+
# The number of allocated resources is doubled: candidates are evaluated on
80+
# twice as many samples.
81+
#
82+
# This process is repeated until the last iteration, where only 2 candidates
83+
# are left. The best candidate is the candidate that has the best score at the
84+
# last iteration.
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""Enables Successive Halving search-estimators
2+
3+
The API and results of these estimators might change without any deprecation
4+
cycle.
5+
6+
Importing this file dynamically sets the
7+
:class:`~sklearn.model_selection.HalvingRandomSearchCV` and
8+
:class:`~sklearn.model_selection.HalvingGridSearchCV` as attributes of the
9+
`model_selection` module::
10+
11+
>>> # explicitly require this experimental feature
12+
>>> from sklearn.experimental import enable_successive_halving # noqa
13+
>>> # now you can import normally from model_selection
14+
>>> from sklearn.model_selection import HalvingRandomSearchCV
15+
>>> from sklearn.model_selection import HalvingGridSearchCV
16+
17+
18+
The ``# noqa`` comment comment can be removed: it just tells linters like
19+
flake8 to ignore the import, which appears as unused.
20+
"""
21+
22+
from ..model_selection._search_successive_halving import (
23+
HalvingRandomSearchCV,
24+
HalvingGridSearchCV
25+
)
26+
27+
from .. import model_selection
28+
29+
# use settattr to avoid mypy errors when monkeypatching
30+
setattr(model_selection, "HalvingRandomSearchCV",
31+
HalvingRandomSearchCV)
32+
setattr(model_selection, "HalvingGridSearchCV",
33+
HalvingGridSearchCV)
34+
35+
model_selection.__all__ += ['HalvingRandomSearchCV', 'HalvingGridSearchCV']
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""Tests for making sure experimental imports work as expected."""
2+
3+
import textwrap
4+
5+
from sklearn.utils._testing import assert_run_python_script
6+
7+
8+
def test_imports_strategies():
9+
# Make sure different import strategies work or fail as expected.
10+
11+
# Since Python caches the imported modules, we need to run a child process
12+
# for every test case. Else, the tests would not be independent
13+
# (manually removing the imports from the cache (sys.modules) is not
14+
# recommended and can lead to many complications).
15+
16+
good_import = """
17+
from sklearn.experimental import enable_successive_halving
18+
from sklearn.model_selection import HalvingGridSearchCV
19+
from sklearn.model_selection import HalvingRandomSearchCV
20+
"""
21+
assert_run_python_script(textwrap.dedent(good_import))
22+
23+
good_import_with_model_selection_first = """
24+
import sklearn.model_selection
25+
from sklearn.experimental import enable_successive_halving
26+
from sklearn.model_selection import HalvingGridSearchCV
27+
from sklearn.model_selection import HalvingRandomSearchCV
28+
"""
29+
assert_run_python_script(
30+
textwrap.dedent(good_import_with_model_selection_first)
31+
)
32+
33+
bad_imports = """
34+
import pytest
35+
36+
with pytest.raises(ImportError):
37+
from sklearn.model_selection import HalvingGridSearchCV
38+
39+
import sklearn.experimental
40+
with pytest.raises(ImportError):
41+
from sklearn.model_selection import HalvingGridSearchCV
42+
"""
43+
assert_run_python_script(textwrap.dedent(bad_imports))

sklearn/model_selection/__init__.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import typing
2+
13
from ._split import BaseCrossValidator
24
from ._split import KFold
35
from ._split import GroupKFold
@@ -29,7 +31,15 @@
2931
from ._search import ParameterSampler
3032
from ._search import fit_grid_point
3133

32-
__all__ = ('BaseCrossValidator',
34+
if typing.TYPE_CHECKING:
35+
# Avoid errors in type checkers (e.g. mypy) for experimental estimators.
36+
# TODO: remove this check once the estimator is no longer experimental.
37+
from ._search_successive_halving import ( # noqa
38+
HalvingGridSearchCV, HalvingRandomSearchCV
39+
)
40+
41+
42+
__all__ = ['BaseCrossValidator',
3343
'GridSearchCV',
3444
'TimeSeriesSplit',
3545
'KFold',
@@ -56,4 +66,4 @@
5666
'learning_curve',
5767
'permutation_test_score',
5868
'train_test_split',
59-
'validation_curve')
69+
'validation_curve']

0 commit comments

Comments
 (0)