Skip to content

Commit 6f1c384

Browse files
committed
Merge pull request scikit-learn#773 from amueller/forest_pre_dispatch
MRG pre_dispatch for foresters
2 parents a5a02d4 + a6f0e2a commit 6f1c384

File tree

1 file changed

+97
-14
lines changed

1 file changed

+97
-14
lines changed

sklearn/ensemble/forest.py

Lines changed: 97 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@ def __init__(self, base_estimator,
172172
oob_score=False,
173173
n_jobs=1,
174174
random_state=None,
175-
verbose=0):
175+
verbose=0,
176+
pre_dispatch="2*n_jobs"):
176177
super(BaseForest, self).__init__(
177178
base_estimator=base_estimator,
178179
n_estimators=n_estimators,
@@ -182,6 +183,7 @@ def __init__(self, base_estimator,
182183
self.compute_importances = compute_importances
183184
self.oob_score = oob_score
184185
self.n_jobs = n_jobs
186+
self.pre_dispatch = pre_dispatch
185187
self.random_state = check_random_state(random_state)
186188

187189
self.feature_importances_ = None
@@ -237,7 +239,8 @@ def fit(self, X, y):
237239
n_jobs, n_trees, _ = _partition_trees(self)
238240

239241
# Parallel loop
240-
all_trees = Parallel(n_jobs=n_jobs, verbose=self.verbose)(
242+
all_trees = Parallel(n_jobs=n_jobs, verbose=self.verbose,
243+
pre_dispatch=self.pre_dispatch)(
241244
delayed(_parallel_build_trees)(
242245
n_trees[i],
243246
self,
@@ -302,7 +305,8 @@ def __init__(self, base_estimator,
302305
oob_score=False,
303306
n_jobs=1,
304307
random_state=None,
305-
verbose=0):
308+
verbose=0,
309+
pre_dispatch="2*n_jobs"):
306310
super(ForestClassifier, self).__init__(
307311
base_estimator,
308312
n_estimators=n_estimators,
@@ -312,7 +316,8 @@ def __init__(self, base_estimator,
312316
oob_score=oob_score,
313317
n_jobs=n_jobs,
314318
random_state=random_state,
315-
verbose=verbose)
319+
verbose=verbose,
320+
pre_dispatch=pre_dispatch)
316321

317322
def predict(self, X):
318323
"""Predict class for X.
@@ -402,7 +407,8 @@ def __init__(self, base_estimator,
402407
oob_score=False,
403408
n_jobs=1,
404409
random_state=None,
405-
verbose=0):
410+
verbose=0,
411+
pre_dispatch="2*n_jobs"):
406412
super(ForestRegressor, self).__init__(
407413
base_estimator,
408414
n_estimators=n_estimators,
@@ -412,7 +418,8 @@ def __init__(self, base_estimator,
412418
oob_score=oob_score,
413419
n_jobs=n_jobs,
414420
random_state=random_state,
415-
verbose=verbose)
421+
verbose=verbose,
422+
pre_dispatch=pre_dispatch)
416423

417424
def predict(self, X):
418425
"""Predict regression target for X.
@@ -527,6 +534,23 @@ class RandomForestClassifier(ForestClassifier):
527534
verbose : int, optional (default=0)
528535
Controls the verbosity of the tree building process.
529536
537+
pre_dispatch : int, string or None, optional (default='2*n_jobs')
538+
Controls the number of jobs that get dispatched during parallel
539+
execution. Reducing this number can be useful to avoid an
540+
explosion of memory consumption when more jobs get dispatched
541+
than CPUs can process. This parameter can be:
542+
543+
- None, in which case all the jobs are immediately
544+
created and spawned. Use this for lightweight and
545+
fast-running jobs, to avoid delays due to on-demand
546+
spawning of the jobs
547+
548+
- An int, giving the exact number of total jobs that are
549+
spawned
550+
551+
- A string, giving an expression as a function of n_jobs,
552+
as in '2*n_jobs'
553+
530554
Attributes
531555
----------
532556
`feature_importances_` : array, shape = [n_features]
@@ -561,7 +585,8 @@ def __init__(self, n_estimators=10,
561585
oob_score=False,
562586
n_jobs=1,
563587
random_state=None,
564-
verbose=0):
588+
verbose=0,
589+
pre_dispatch="2*n_jobs"):
565590
super(RandomForestClassifier, self).__init__(
566591
base_estimator=DecisionTreeClassifier(),
567592
n_estimators=n_estimators,
@@ -573,7 +598,8 @@ def __init__(self, n_estimators=10,
573598
oob_score=oob_score,
574599
n_jobs=n_jobs,
575600
random_state=random_state,
576-
verbose=verbose)
601+
verbose=verbose,
602+
pre_dispatch=pre_dispatch)
577603

578604
self.criterion = criterion
579605
self.max_depth = max_depth
@@ -662,6 +688,23 @@ class RandomForestRegressor(ForestRegressor):
662688
verbose : int, optional (default=0)
663689
Controls the verbosity of the tree building process.
664690
691+
pre_dispatch : int, string or None, optional (default='2*n_jobs')
692+
Controls the number of jobs that get dispatched during parallel
693+
execution. Reducing this number can be useful to avoid an
694+
explosion of memory consumption when more jobs get dispatched
695+
than CPUs can process. This parameter can be:
696+
697+
- None, in which case all the jobs are immediately
698+
created and spawned. Use this for lightweight and
699+
fast-running jobs, to avoid delays due to on-demand
700+
spawning of the jobs
701+
702+
- An int, giving the exact number of total jobs that are
703+
spawned
704+
705+
- A string, giving an expression as a function of n_jobs,
706+
as in '2*n_jobs'
707+
665708
Attributes
666709
----------
667710
`feature_importances_` : array of shape = [n_features]
@@ -696,7 +739,8 @@ def __init__(self, n_estimators=10,
696739
oob_score=False,
697740
n_jobs=1,
698741
random_state=None,
699-
verbose=0):
742+
verbose=0,
743+
pre_dispatch="2*n_jobs"):
700744
super(RandomForestRegressor, self).__init__(
701745
base_estimator=DecisionTreeRegressor(),
702746
n_estimators=n_estimators,
@@ -708,7 +752,8 @@ def __init__(self, n_estimators=10,
708752
oob_score=oob_score,
709753
n_jobs=n_jobs,
710754
random_state=random_state,
711-
verbose=verbose)
755+
verbose=verbose,
756+
pre_dispatch=pre_dispatch)
712757

713758
self.criterion = criterion
714759
self.max_depth = max_depth
@@ -798,6 +843,23 @@ class ExtraTreesClassifier(ForestClassifier):
798843
verbose : int, optional (default=0)
799844
Controls the verbosity of the tree building process.
800845
846+
pre_dispatch : int, string or None, optional (default='2*n_jobs')
847+
Controls the number of jobs that get dispatched during parallel
848+
execution. Reducing this number can be useful to avoid an
849+
explosion of memory consumption when more jobs get dispatched
850+
than CPUs can process. This parameter can be:
851+
852+
- None, in which case all the jobs are immediately
853+
created and spawned. Use this for lightweight and
854+
fast-running jobs, to avoid delays due to on-demand
855+
spawning of the jobs
856+
857+
- An int, giving the exact number of total jobs that are
858+
spawned
859+
860+
- A string, giving an expression as a function of n_jobs,
861+
as in '2*n_jobs'
862+
801863
Attributes
802864
----------
803865
`feature_importances_` : array of shape = [n_features]
@@ -834,7 +896,8 @@ def __init__(self, n_estimators=10,
834896
oob_score=False,
835897
n_jobs=1,
836898
random_state=None,
837-
verbose=0):
899+
verbose=0,
900+
pre_dispatch="2*n_jobs"):
838901
super(ExtraTreesClassifier, self).__init__(
839902
base_estimator=ExtraTreeClassifier(),
840903
n_estimators=n_estimators,
@@ -846,7 +909,8 @@ def __init__(self, n_estimators=10,
846909
oob_score=oob_score,
847910
n_jobs=n_jobs,
848911
random_state=random_state,
849-
verbose=verbose)
912+
verbose=verbose,
913+
pre_dispatch=pre_dispatch)
850914

851915
self.criterion = criterion
852916
self.max_depth = max_depth
@@ -937,6 +1001,23 @@ class ExtraTreesRegressor(ForestRegressor):
9371001
verbose : int, optional (default=0)
9381002
Controls the verbosity of the tree building process.
9391003
1004+
pre_dispatch : int, string or None, optional (default='2*n_jobs')
1005+
Controls the number of jobs that get dispatched during parallel
1006+
execution. Reducing this number can be useful to avoid an
1007+
explosion of memory consumption when more jobs get dispatched
1008+
than CPUs can process. This parameter can be:
1009+
1010+
- None, in which case all the jobs are immediately
1011+
created and spawned. Use this for lightweight and
1012+
fast-running jobs, to avoid delays due to on-demand
1013+
spawning of the jobs
1014+
1015+
- An int, giving the exact number of total jobs that are
1016+
spawned
1017+
1018+
- A string, giving an expression as a function of n_jobs,
1019+
as in '2*n_jobs'
1020+
9401021
Attributes
9411022
----------
9421023
`feature_importances_` : array of shape = [n_features]
@@ -971,7 +1052,8 @@ def __init__(self, n_estimators=10,
9711052
oob_score=False,
9721053
n_jobs=1,
9731054
random_state=None,
974-
verbose=0):
1055+
verbose=0,
1056+
pre_dispatch="2*n_jobs"):
9751057
super(ExtraTreesRegressor, self).__init__(
9761058
base_estimator=ExtraTreeRegressor(),
9771059
n_estimators=n_estimators,
@@ -983,7 +1065,8 @@ def __init__(self, n_estimators=10,
9831065
oob_score=oob_score,
9841066
n_jobs=n_jobs,
9851067
random_state=random_state,
986-
verbose=verbose)
1068+
verbose=verbose,
1069+
pre_dispatch=pre_dispatch)
9871070

9881071
self.criterion = criterion
9891072
self.max_depth = max_depth

0 commit comments

Comments
 (0)