Commit 2f563df

betatim authored and arjoly committed
Added warm_start to bagging
BaggingClassifier and BaggingRegressor now support warm_start. Basic tests and documentation for the new functionality are included. Heavily inspired by the warm_start work for random forests.
1 parent 88e0c69 commit 2f563df
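
The new option mirrors warm_start in the forest classes; a minimal usage sketch (an illustration, not part of the commit):

    from sklearn.datasets import make_hastie_10_2
    from sklearn.ensemble import BaggingClassifier

    X, y = make_hastie_10_2(n_samples=100, random_state=1)

    # The first fit builds five estimators.
    clf = BaggingClassifier(n_estimators=5, warm_start=True, random_state=42)
    clf.fit(X, y)

    # Raising n_estimators and refitting adds only the five missing
    # estimators; the previously fitted ones are kept.
    clf.set_params(n_estimators=10)
    clf.fit(X, y)
    assert len(clf.estimators_) == 10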

File tree

2 files changed (+114 −25 lines)


sklearn/ensemble/bagging.py

Lines changed: 60 additions & 24 deletions
@@ -65,7 +65,7 @@ def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight,
             print("building estimator %d of %d" % (i + 1, n_estimators))
 
         random_state = check_random_state(seeds[i])
-        seed = check_random_state(random_state.randint(MAX_INT))
+        seed = random_state.randint(MAX_INT)
         estimator = ensemble._make_estimator(append=False)
 
         try:  # Not all estimators accept a random_state
@@ -206,6 +206,7 @@ def __init__(self,
                  bootstrap=True,
                  bootstrap_features=False,
                  oob_score=False,
+                 warm_start=False,
                  n_jobs=1,
                  random_state=None,
                  verbose=0):
@@ -218,6 +219,7 @@ def __init__(self,
         self.bootstrap = bootstrap
         self.bootstrap_features = bootstrap_features
         self.oob_score = oob_score
+        self.warm_start = warm_start
         self.n_jobs = n_jobs
         self.random_state = random_state
         self.verbose = verbose
@@ -278,40 +280,60 @@ def fit(self, X, y, sample_weight=None):
             raise ValueError("Out of bag estimation only available"
                              " if bootstrap=True")
 
-        # Free allocated memory, if any
-        self.estimators_ = None
+        if not self.warm_start:
+            # Free allocated memory, if any
+            self.estimators_ = []
+            self.estimators_samples_ = []
+            self.estimators_features_ = []
 
-        # Parallel loop
-        n_jobs, n_estimators, starts = _partition_estimators(self.n_estimators,
-                                                             self.n_jobs)
-        seeds = random_state.randint(MAX_INT, size=self.n_estimators)
+        n_more_estimators = self.n_estimators - len(self.estimators_)
 
-        all_results = Parallel(n_jobs=n_jobs, verbose=self.verbose)(
-            delayed(_parallel_build_estimators)(
-                n_estimators[i],
-                self,
-                X,
-                y,
-                sample_weight,
-                seeds[starts[i]:starts[i + 1]],
-                verbose=self.verbose)
-            for i in range(n_jobs))
+        if n_more_estimators < 0:
+            raise ValueError('n_estimators=%d must be larger or equal to '
+                             'len(estimators_)=%d when warm_start==True'
+                             % (self.n_estimators, len(self.estimators_)))
+
+        elif n_more_estimators == 0:
+            warn("Warm-start fitting without increasing n_estimators does not "
+                 "fit new trees.")
+        else:
+            # Parallel loop
+            n_jobs, n_estimators, starts = _partition_estimators(n_more_estimators,
+                                                                 self.n_jobs)
+
+            # Advance random state to state after training
+            # the first n_estimators
+            if self.warm_start and len(self.estimators_) > 0:
+                random_state.randint(MAX_INT, size=len(self.estimators_))
+
+            seeds = random_state.randint(MAX_INT, size=n_more_estimators)
+
+            all_results = Parallel(n_jobs=n_jobs, verbose=self.verbose)(
+                delayed(_parallel_build_estimators)(
+                    n_estimators[i],
+                    self,
+                    X,
+                    y,
+                    sample_weight,
+                    seeds[starts[i]:starts[i + 1]],
+                    verbose=self.verbose)
+                for i in range(n_jobs))
 
         # Reduce
-        self.estimators_ = list(itertools.chain.from_iterable(
+        self.estimators_ += list(itertools.chain.from_iterable(
             t[0] for t in all_results))
         self.estimators_samples_ = list(itertools.chain.from_iterable(
             t[1] for t in all_results))
         self.estimators_features_ = list(itertools.chain.from_iterable(
             t[2] for t in all_results))
 
         if self.oob_score:
-            self._set_oob_score(X, y)
+            self._set_oob_score(X, y, n_more_estimators)
 
         return self
 
     @abstractmethod
-    def _set_oob_score(self, X, y):
+    def _set_oob_score(self, X, y, n_more_estimators):
         """Calculate out of bag predictions and score."""
 
     def _validate_y(self, y):
@@ -368,6 +390,11 @@ class BaggingClassifier(BaseBagging, ClassifierMixin):
         Whether to use out-of-bag samples to estimate
         the generalization error.
 
+    warm_start : bool, optional (default=False)
+        When set to True, reuse the solution of the previous call to fit
+        and add more estimators to the ensemble; otherwise, just fit
+        a whole new ensemble.
+
     n_jobs : int, optional (default=1)
         The number of jobs to run in parallel for both `fit` and `predict`.
         If -1, then the number of jobs is set to the number of cores.
@@ -435,6 +462,7 @@ def __init__(self,
                  bootstrap=True,
                  bootstrap_features=False,
                  oob_score=False,
+                 warm_start=False,
                  n_jobs=1,
                  random_state=None,
                  verbose=0):
@@ -447,6 +475,7 @@ def __init__(self,
            bootstrap=bootstrap,
            bootstrap_features=bootstrap_features,
            oob_score=oob_score,
+           warm_start=warm_start,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose)
@@ -456,14 +485,14 @@ def _validate_estimator(self):
         super(BaggingClassifier, self)._validate_estimator(
             default=DecisionTreeClassifier())
 
-    def _set_oob_score(self, X, y):
+    def _set_oob_score(self, X, y, n_more_estimators):
         n_classes_ = self.n_classes_
         classes_ = self.classes_
         n_samples = y.shape[0]
 
         predictions = np.zeros((n_samples, n_classes_))
 
-        for estimator, samples, features in zip(self.estimators_,
+        for estimator, samples, features in zip(self.estimators_[-n_more_estimators:],
                                                 self.estimators_samples_,
                                                 self.estimators_features_):
             mask = np.ones(n_samples, dtype=np.bool)
@@ -724,6 +753,11 @@ class BaggingRegressor(BaseBagging, RegressorMixin):
         Whether to use out-of-bag samples to estimate
         the generalization error.
 
+    warm_start : bool, optional (default=False)
+        When set to True, reuse the solution of the previous call to fit
+        and add more estimators to the ensemble; otherwise, just fit
+        a whole new ensemble.
+
     n_jobs : int, optional (default=1)
         The number of jobs to run in parallel for both `fit` and `predict`.
         If -1, then the number of jobs is set to the number of cores.
@@ -783,6 +817,7 @@ def __init__(self,
                  bootstrap=True,
                  bootstrap_features=False,
                  oob_score=False,
+                 warm_start=False,
                  n_jobs=1,
                  random_state=None,
                  verbose=0):
@@ -794,6 +829,7 @@ def __init__(self,
            bootstrap=bootstrap,
            bootstrap_features=bootstrap_features,
            oob_score=oob_score,
+           warm_start=warm_start,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose)
@@ -840,13 +876,13 @@ def _validate_estimator(self):
         super(BaggingRegressor, self)._validate_estimator(
             default=DecisionTreeRegressor())
 
-    def _set_oob_score(self, X, y):
+    def _set_oob_score(self, X, y, n_more_estimators):
         n_samples = y.shape[0]
 
         predictions = np.zeros((n_samples,))
         n_predictions = np.zeros((n_samples,))
 
-        for estimator, samples, features in zip(self.estimators_,
+        for estimator, samples, features in zip(self.estimators_[-n_more_estimators:],
                                                 self.estimators_samples_,
                                                 self.estimators_features_):
             mask = np.ones(n_samples, dtype=np.bool)
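
The "Advance random state" step in fit is what makes warm-started fits reproducible: on a warm-started call the generator first discards the len(estimators_) seeds consumed by earlier fits, so a 5+5 run draws the same seed sequence as a single 10-estimator run. A standalone NumPy sketch of this behaviour (an illustration, not code from the diff):

    import numpy as np

    MAX_INT = np.iinfo(np.int32).max

    # A single fit draws all ten seeds at once.
    seeds_full = np.random.RandomState(42).randint(MAX_INT, size=10)

    # Warm start: the first fit draws five seeds ...
    rs = np.random.RandomState(42)
    first = rs.randint(MAX_INT, size=5)

    # ... and the second fit restarts the generator, skips the five
    # already-consumed seeds, then draws the five new ones.
    rs = np.random.RandomState(42)
    rs.randint(MAX_INT, size=5)  # advance past the first fit
    second = rs.randint(MAX_INT, size=5)

    assert (seeds_full == np.concatenate([first, second])).all()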

sklearn/ensemble/tests/test_bagging.py

Lines changed: 54 additions & 1 deletion
@@ -29,7 +29,7 @@
 from sklearn.pipeline import make_pipeline
 from sklearn.feature_selection import SelectKBest
 from sklearn.cross_validation import train_test_split
-from sklearn.datasets import load_boston, load_iris
+from sklearn.datasets import load_boston, load_iris, make_hastie_10_2
 from sklearn.utils import check_random_state
 
 from scipy.sparse import csc_matrix, csr_matrix
@@ -571,6 +571,59 @@ def test_bagging_sample_weight_unsupported_but_passed():
     assert_raises(ValueError, estimator.fit, iris.data, iris.target,
                   sample_weight=rng.randint(10, size=(iris.data.shape[0])))
 
+
+def test_warm_start(random_state=42):
+    # Test if fitting incrementally with warm start gives an ensemble of the
+    # right size and the same results as a normal fit.
+    X, y = make_hastie_10_2(n_samples=20, random_state=1)
+
+    clf_ws = None
+    for n_estimators in [5, 10]:
+        if clf_ws is None:
+            clf_ws = BaggingClassifier(n_estimators=n_estimators,
+                                       random_state=random_state,
+                                       warm_start=True)
+        else:
+            clf_ws.set_params(n_estimators=n_estimators)
+        clf_ws.fit(X, y)
+        assert_equal(len(clf_ws), n_estimators)
+
+    clf_no_ws = BaggingClassifier(n_estimators=10, random_state=random_state,
+                                  warm_start=False)
+    clf_no_ws.fit(X, y)
+
+    assert_equal(set([tree.random_state for tree in clf_ws]),
+                 set([tree.random_state for tree in clf_no_ws]))
+
+
+def test_warm_start_smaller_n_estimators():
+    # Test that a warm-started second fit with a smaller n_estimators
+    # raises an error.
+    X, y = make_hastie_10_2(n_samples=20, random_state=1)
+    clf = BaggingClassifier(n_estimators=5, warm_start=True)
+    clf.fit(X, y)
+    clf.set_params(n_estimators=4)
+    assert_raises(ValueError, clf.fit, X, y)
+
+
+def test_warm_start_equivalence():
+    # A warm-started classifier with 5+5 estimators should be equivalent to
+    # one classifier fitted with 10 estimators from scratch.
+    X, y = make_hastie_10_2(n_samples=20, random_state=1)
+    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=43)
+
+    clf_ws = BaggingClassifier(n_estimators=5, warm_start=True)
+    clf_ws.fit(X_train, y_train)
+    clf_ws.set_params(n_estimators=10)
+    clf_ws.fit(X_train, y_train)
+    y1 = clf_ws.predict(X_test)
+
+    clf = BaggingClassifier(n_estimators=10, warm_start=False)
+    clf.fit(X_train, y_train)
+    y2 = clf.predict(X_test)
+
+    assert_array_almost_equal(y1, y2)
+
+
 if __name__ == "__main__":
     import nose
     nose.runmodule()
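
The assert_equal(len(clf_ws), n_estimators) check and the `for tree in clf_ws` loops in test_warm_start rely on the ensemble base class exposing its fitted estimators through the sequence protocol, roughly as follows (a paraphrase for illustration, not part of this diff):

    def __len__(self):
        """Return the number of estimators in the ensemble."""
        return len(self.estimators_)

    def __iter__(self):
        """Return iterator over estimators in the ensemble."""
        return iter(self.estimators_)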
