Skip to content

Commit 7488fc0

Browse files
committed
Merge pull request scikit-learn#4840 from jnothman/dynamicgrid
[MRG] FIX avoid memory cost when sampling from large parameter grids
2 parents 1c07dec + febefb0 commit 7488fc0

File tree

3 files changed

+61
-5
lines changed

3 files changed

+61
-5
lines changed

doc/whats_new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ Bug fixes
9797
- Fixed bug in :class:`decomposition.DictLearning` when ``n_jobs < 0``. By
9898
`Andreas Müller`_.
9999

100+
- Fixed bug where :class:`grid_search.RandomizedSearchCV` could consume a
101+
lot of memory for large discrete grids. By `Joel Nothman`_.
102+
100103
API changes summary
101104
-------------------
102105

sklearn/grid_search.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ class ParameterGrid(object):
7070
... {'kernel': 'rbf', 'gamma': 1},
7171
... {'kernel': 'rbf', 'gamma': 10}]
7272
True
73+
>>> ParameterGrid(grid)[1] == {'kernel': 'rbf', 'gamma': 1}
74+
True
7375
7476
See also
7577
--------
@@ -111,6 +113,47 @@ def __len__(self):
111113
return sum(product(len(v) for v in p.values()) if p else 1
112114
for p in self.param_grid)
113115

116+
def __getitem__(self, ind):
117+
"""Get the parameters that would be ``ind``th in iteration
118+
119+
Parameters
120+
----------
121+
ind : int
122+
The iteration index
123+
124+
Returns
125+
-------
126+
params : dict of string to any
127+
Equal to list(self)[ind]
128+
"""
129+
# This is used to make discrete sampling without replacement memory
130+
# efficient.
131+
for sub_grid in self.param_grid:
132+
# XXX: could memoize information used here
133+
if not sub_grid:
134+
if ind == 0:
135+
return {}
136+
else:
137+
ind -= 1
138+
continue
139+
140+
# Reverse so most frequent cycling parameter comes first
141+
keys, values_lists = zip(*sorted(sub_grid.items())[::-1])
142+
sizes = [len(v_list) for v_list in values_lists]
143+
total = np.product(sizes)
144+
145+
if ind >= total:
146+
# Try the next grid
147+
ind -= total
148+
else:
149+
out = {}
150+
for key, v_list, n in zip(keys, values_lists, sizes):
151+
ind, offset = divmod(ind, n)
152+
out[key] = v_list[offset]
153+
return out
154+
155+
raise IndexError('ParameterGrid index out of range')
156+
114157

115158
class ParameterSampler(object):
116159
"""Generator on parameters sampled from given distributions.
@@ -181,8 +224,8 @@ def __iter__(self):
181224
rnd = check_random_state(self.random_state)
182225

183226
if all_lists:
184-
# get complete grid and yield from it
185-
param_grid = list(ParameterGrid(self.param_distributions))
227+
# look up sampled parameter settings in parameter grid
228+
param_grid = ParameterGrid(self.param_distributions)
186229
grid_size = len(param_grid)
187230

188231
if grid_size < self.n_iter:

sklearn/tests/test_grid_search.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,18 @@ def score(self):
9292
y = np.array([1, 1, 2, 2])
9393

9494

95+
def assert_grid_iter_equals_getitem(grid):
96+
assert_equal(list(grid), [grid[i] for i in range(len(grid))])
97+
98+
9599
def test_parameter_grid():
96100
# Test basic properties of ParameterGrid.
97101
params1 = {"foo": [1, 2, 3]}
98102
grid1 = ParameterGrid(params1)
99103
assert_true(isinstance(grid1, Iterable))
100104
assert_true(isinstance(grid1, Sized))
101105
assert_equal(len(grid1), 3)
106+
assert_grid_iter_equals_getitem(grid1)
102107

103108
params2 = {"foo": [4, 2],
104109
"bar": ["ham", "spam", "eggs"]}
@@ -113,14 +118,19 @@ def test_parameter_grid():
113118
set(("bar", x, "foo", y)
114119
for x, y in product(params2["bar"], params2["foo"])))
115120

121+
assert_grid_iter_equals_getitem(grid2)
122+
116123
# Special case: empty grid (useful to get default estimator settings)
117124
empty = ParameterGrid({})
118125
assert_equal(len(empty), 1)
119126
assert_equal(list(empty), [{}])
127+
assert_grid_iter_equals_getitem(empty)
128+
assert_raises(IndexError, lambda: empty[1])
120129

121-
has_empty = ParameterGrid([{'C': [1, 10]}, {}])
122-
assert_equal(len(has_empty), 3)
123-
assert_equal(list(has_empty), [{'C': 1}, {'C': 10}, {}])
130+
has_empty = ParameterGrid([{'C': [1, 10]}, {}, {'C': [.5]}])
131+
assert_equal(len(has_empty), 4)
132+
assert_equal(list(has_empty), [{'C': 1}, {'C': 10}, {}, {'C': .5}])
133+
assert_grid_iter_equals_getitem(has_empty)
124134

125135

126136
def test_grid_search():

0 commit comments

Comments
 (0)