Skip to content

Commit 9be0922

Browse files
mikebenfield authored and jmschrei committed
[MRG+1] Fix excessive memory usage in random forest prediction (scikit-learn#8672)
[MRG+2] Fix excessive memory usage in random forest prediction
1 parent 3003da7 commit 9be0922

File tree

1 file changed

+31
-25
lines changed

1 file changed

+31
-25
lines changed

sklearn/ensemble/forest.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,19 @@ def feature_importances_(self):
374374
return sum(all_importances) / len(self.estimators_)
375375

376376

377+
# This is a utility function for joblib's Parallel. It can't go locally in
# ForestClassifier or ForestRegressor, because joblib complains that it cannot
# pickle it when placed there.

def accumulate_prediction(predict, X, out, lock=None):
    """Sum one estimator's prediction into the shared ``out`` buffers.

    Called once per estimator by joblib's ``Parallel`` so that per-estimator
    predictions are reduced on the fly instead of being stored all at once
    (this is the memory fix this commit introduces).

    Parameters
    ----------
    predict : callable
        Bound ``predict`` / ``predict_proba`` of a fitted tree; invoked as
        ``predict(X, check_input=False)``.
    X : array-like
        Input samples, forwarded unchanged to ``predict``.
    out : list
        Mutable accumulators, one entry per output; updated in place
        via ``+=``.
    lock : threading.Lock, optional
        With ``backend="threading"``, in-place ``+=`` on a shared numpy
        buffer is not atomic (numpy may release the GIL mid-operation), so
        concurrent workers can lose updates. Pass a lock to serialize the
        accumulation. ``None`` (the default) keeps the previous, unlocked
        behavior, so existing callers are unaffected.
    """
    prediction = predict(X, check_input=False)
    if lock is not None:
        lock.acquire()
    try:
        if len(out) == 1:
            # Single output: `prediction` is the whole accumulable result.
            out[0] += prediction
        else:
            # Multi-output: accumulate each output slot separately.
            for i in range(len(out)):
                out[i] += prediction[i]
    finally:
        if lock is not None:
            lock.release()
388+
389+
377390
class ForestClassifier(six.with_metaclass(ABCMeta, BaseForest,
378391
ClassifierMixin)):
379392
"""Base class for forest of trees-based classifiers.
@@ -565,31 +578,20 @@ class in a leaf.
565578
# Assign chunk of trees to jobs
566579
n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs)
567580

568-
# Parallel loop
569-
all_proba = Parallel(n_jobs=n_jobs, verbose=self.verbose,
570-
backend="threading")(
571-
delayed(parallel_helper)(e, 'predict_proba', X,
572-
check_input=False)
581+
# avoid storing the output of every estimator by summing them here
582+
all_proba = [np.zeros((X.shape[0], j), dtype=np.float64)
583+
for j in np.atleast_1d(self.n_classes_)]
584+
Parallel(n_jobs=n_jobs, verbose=self.verbose, backend="threading")(
585+
delayed(accumulate_prediction)(e.predict_proba, X, all_proba)
573586
for e in self.estimators_)
574587

575-
# Reduce
576-
proba = all_proba[0]
577-
578-
if self.n_outputs_ == 1:
579-
for j in range(1, len(all_proba)):
580-
proba += all_proba[j]
581-
588+
for proba in all_proba:
582589
proba /= len(self.estimators_)
583590

591+
if len(all_proba) == 1:
592+
return all_proba[0]
584593
else:
585-
for j in range(1, len(all_proba)):
586-
for k in range(self.n_outputs_):
587-
proba[k] += all_proba[j][k]
588-
589-
for k in range(self.n_outputs_):
590-
proba[k] /= self.n_estimators
591-
592-
return proba
594+
return all_proba
593595

594596
def predict_log_proba(self, X):
595597
"""Predict class log-probabilities for X.
@@ -678,14 +680,18 @@ def predict(self, X):
678680
# Assign chunk of trees to jobs
679681
n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs)
680682

683+
# avoid storing the output of every estimator by summing them here
684+
if self.n_outputs_ > 1:
685+
y_hat = np.zeros((X.shape[0], self.n_outputs_), dtype=np.float64)
686+
else:
687+
y_hat = np.zeros((X.shape[0]), dtype=np.float64)
688+
681689
# Parallel loop
682-
all_y_hat = Parallel(n_jobs=n_jobs, verbose=self.verbose,
683-
backend="threading")(
684-
delayed(parallel_helper)(e, 'predict', X, check_input=False)
690+
Parallel(n_jobs=n_jobs, verbose=self.verbose, backend="threading")(
691+
delayed(accumulate_prediction)(e.predict, X, [y_hat])
685692
for e in self.estimators_)
686693

687-
# Reduce
688-
y_hat = sum(all_y_hat) / len(self.estimators_)
694+
y_hat /= len(self.estimators_)
689695

690696
return y_hat
691697

0 commit comments

Comments
 (0)