
Commit cc74076

glouppe authored and larsmans committed
ENH: move _partition_estimators to ensemble.base
1 parent 74d3952 commit cc74076

3 files changed: +28 -49 lines changed

sklearn/ensemble/bagging.py

Lines changed: 2 additions & 20 deletions
@@ -13,7 +13,7 @@
 from inspect import getargspec

 from ..base import ClassifierMixin, RegressorMixin
-from ..externals.joblib import Parallel, delayed, cpu_count
+from ..externals.joblib import Parallel, delayed
 from ..externals.six import with_metaclass
 from ..externals.six.moves import zip
 from ..metrics import r2_score, accuracy_score
@@ -22,7 +22,7 @@
 from ..utils.fixes import bincount, unique
 from ..utils.random import sample_without_replacement

-from .base import BaseEnsemble
+from .base import BaseEnsemble, _partition_estimators

 __all__ = ["BaggingClassifier",
            "BaggingRegressor"]
@@ -186,24 +186,6 @@ def _parallel_predict_regression(estimators, estimators_features, X):
                                              estimators_features))


-def _partition_estimators(ensemble):
-    """Private function used to partition estimators between jobs."""
-    # Compute the number of jobs
-    if ensemble.n_jobs == -1:
-        n_jobs = min(cpu_count(), ensemble.n_estimators)
-
-    else:
-        n_jobs = min(ensemble.n_jobs, ensemble.n_estimators)
-
-    # Partition estimators between jobs
-    n_estimators = (ensemble.n_estimators // n_jobs) * np.ones(n_jobs,
-                                                               dtype=np.int)
-    n_estimators[:ensemble.n_estimators % n_jobs] += 1
-    starts = np.cumsum(n_estimators)
-
-    return n_jobs, n_estimators.tolist(), [0] + starts.tolist()
-
-
 class BaseBagging(with_metaclass(ABCMeta, BaseEnsemble)):
     """Base class for Bagging meta-estimator.

sklearn/ensemble/base.py

Lines changed: 21 additions & 0 deletions
@@ -5,9 +5,12 @@
 # Authors: Gilles Louppe
 # License: BSD 3 clause

+import numpy as np
+
 from ..base import clone
 from ..base import BaseEstimator
 from ..base import MetaEstimatorMixin
+from ..externals.joblib import cpu_count


 class BaseEnsemble(BaseEstimator, MetaEstimatorMixin):
@@ -78,3 +81,21 @@ def __getitem__(self, index):
     def __iter__(self):
         """Returns iterator over estimators in the ensemble."""
         return iter(self.estimators_)
+
+
+def _partition_estimators(ensemble):
+    """Private function used to partition estimators between jobs."""
+    # Compute the number of jobs
+    if ensemble.n_jobs == -1:
+        n_jobs = min(cpu_count(), ensemble.n_estimators)
+
+    else:
+        n_jobs = min(ensemble.n_jobs, ensemble.n_estimators)
+
+    # Partition estimators between jobs
+    n_estimators = (ensemble.n_estimators // n_jobs) * np.ones(n_jobs,
+                                                               dtype=np.int)
+    n_estimators[:ensemble.n_estimators % n_jobs] += 1
+    starts = np.cumsum(n_estimators)
+
+    return n_jobs, n_estimators.tolist(), [0] + starts.tolist()
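
As a side note on what the relocated helper computes, here is a small standalone sketch, not part of this commit, that reproduces the same arithmetic for a hypothetical ensemble with 10 estimators and n_jobs=4. The _ToyEnsemble stand-in and the partition function are illustrative names; the n_jobs == -1 / cpu_count branch is omitted for brevity, and np.int is replaced by int so the sketch runs on current NumPy.

import numpy as np


class _ToyEnsemble(object):
    # Hypothetical stand-in: the helper only reads n_estimators and
    # n_jobs from the ensemble it is given.
    def __init__(self, n_estimators, n_jobs):
        self.n_estimators = n_estimators
        self.n_jobs = n_jobs


def partition(ensemble):
    # Same arithmetic as _partition_estimators above.
    n_jobs = min(ensemble.n_jobs, ensemble.n_estimators)
    counts = (ensemble.n_estimators // n_jobs) * np.ones(n_jobs, dtype=int)
    counts[:ensemble.n_estimators % n_jobs] += 1
    starts = np.cumsum(counts)
    return n_jobs, counts.tolist(), [0] + starts.tolist()


n_jobs, counts, starts = partition(_ToyEnsemble(n_estimators=10, n_jobs=4))
print(n_jobs, counts, starts)   # 4 [3, 3, 2, 2] [0, 3, 6, 8, 10]

When n_estimators is not divisible by n_jobs, the first jobs each take one extra estimator, and the cumulative starts offsets delimit contiguous slices of the estimator list.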

sklearn/ensemble/forest.py

Lines changed: 5 additions & 29 deletions
@@ -44,7 +44,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
 from abc import ABCMeta, abstractmethod

 from ..base import ClassifierMixin, RegressorMixin
-from ..externals.joblib import Parallel, delayed, cpu_count
+from ..externals.joblib import Parallel, delayed
 from ..externals import six
 from ..externals.six.moves import xrange
 from ..feature_selection.from_model import _LearntSelectorMixin
@@ -57,8 +57,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
 from ..utils.validation import DataConversionWarning
 from ..utils.fixes import bincount, unique

-
-from .base import BaseEnsemble
+from .base import BaseEnsemble, _partition_estimators

 __all__ = ["RandomForestClassifier",
            "RandomForestRegressor",
@@ -151,29 +150,6 @@ def _parallel_predict_regression(trees, X):
     return sum(tree.predict(X) for tree in trees)


-def _partition_trees(forest):
-    """Private function used to partition trees between jobs."""
-    # Compute the number of jobs
-    if forest.n_jobs == -1:
-        n_jobs = min(cpu_count(), forest.n_estimators)
-
-    else:
-        n_jobs = min(forest.n_jobs, forest.n_estimators)
-
-    # Partition trees between jobs
-    n_trees = [forest.n_estimators // n_jobs] * n_jobs
-
-    for i in range(forest.n_estimators % n_jobs):
-        n_trees[i] += 1
-
-    starts = [0] * (n_jobs + 1)
-
-    for i in range(1, n_jobs + 1):
-        starts[i] = starts[i - 1] + n_trees[i - 1]
-
-    return n_jobs, n_trees, starts
-
-
 class BaseForest(six.with_metaclass(ABCMeta, BaseEnsemble,
                                     _LearntSelectorMixin)):
     """Base class for forests of trees.
@@ -286,7 +262,7 @@ def fit(self, X, y, sample_weight=None):
                              " if bootstrap=True")

         # Assign chunk of trees to jobs
-        n_jobs, n_trees, _ = _partition_trees(self)
+        n_jobs, n_trees, _ = _partition_estimators(self)

         # Precalculate the random states
         seeds = [random_state.randint(MAX_INT, size=i) for i in n_trees]
@@ -481,7 +457,7 @@ def predict_proba(self, X):
         X = array2d(X, dtype=DTYPE)

         # Assign chunk of trees to jobs
-        n_jobs, n_trees, starts = _partition_trees(self)
+        n_jobs, n_trees, starts = _partition_estimators(self)

         # Parallel loop
         all_proba = Parallel(n_jobs=n_jobs, verbose=self.verbose)(
@@ -590,7 +566,7 @@ def predict(self, X):
         X = array2d(X, dtype=DTYPE)

         # Assign chunk of trees to jobs
-        n_jobs, n_trees, starts = _partition_trees(self)
+        n_jobs, n_trees, starts = _partition_estimators(self)

         # Parallel loop
         all_y_hat = Parallel(n_jobs=n_jobs, verbose=self.verbose)(
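
The Parallel calls are truncated in the hunks above, so here is a hypothetical sketch of how the returned starts offsets are typically used: each job gets the contiguous slice estimators[starts[i]:starts[i + 1]]. It uses plain joblib rather than scikit-learn's vendored copy, and ConstantEstimator plus _toy_partial_sum are toy stand-ins for the forest's trees and its _parallel_predict_* helpers.

import numpy as np
from joblib import Parallel, delayed


class ConstantEstimator(object):
    # Toy estimator: always predicts the same constant value.
    def __init__(self, value):
        self.value = value

    def predict(self, X):
        return np.full(X.shape[0], self.value, dtype=float)


def _toy_partial_sum(estimators, X):
    # Stand-in for a _parallel_predict_* helper: each job sums the
    # predictions of its own slice of estimators.
    return sum(est.predict(X) for est in estimators)


estimators = [ConstantEstimator(v) for v in range(10)]
X = np.zeros((3, 2))

# The kind of values _partition_estimators would return for this setup
# (counts corresponds to the n_trees list used for seeding in fit).
n_jobs, counts, starts = 4, [3, 3, 2, 2], [0, 3, 6, 8, 10]

partial_sums = Parallel(n_jobs=n_jobs)(
    delayed(_toy_partial_sum)(estimators[starts[i]:starts[i + 1]], X)
    for i in range(n_jobs))

y_hat = sum(partial_sums) / len(estimators)
print(y_hat)  # [4.5 4.5 4.5]

Summing the per-job partial sums and dividing by the number of estimators mirrors how the regression prediction above averages its trees' outputs.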

Comments: 0