Commit 60bd856

Merge pull request scikit-learn#3400 from arjoly/forest-test-oob

ENH improve forest testing + avoid *args

2 parents: be5826d + 3209281

File tree: 2 files changed (+117, -40 lines)
sklearn/ensemble/forest.py

Lines changed: 1 addition & 1 deletion

@@ -1381,7 +1381,7 @@ def __init__(self,
             warn("The min_density parameter is deprecated as of version 0.14 "
                  "and will be removed in 0.16.", DeprecationWarning)
 
-    def _set_oob_score(*args):
+    def _set_oob_score(self, X, y):
         raise NotImplementedError("OOB score not supported by tree embedding")
 
     def fit(self, X, y=None, sample_weight=None):
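The signature change above matters because the forest base class calls this hook with positional (X, y) arguments; a bare *args signature would silently accept a call of any arity. A minimal sketch of the pattern, using hypothetical class names rather than the actual scikit-learn hierarchy:

class BaseForestSketch(object):
    """Hypothetical stand-in for the forest base class."""

    def fit(self, X, y):
        # The base class invokes the hook with a fixed (X, y) call site.
        if getattr(self, "oob_score", False):
            self._set_oob_score(X, y)
        return self


class TreeEmbeddingSketch(BaseForestSketch):
    # Matching the base-class call site explicitly documents the expected
    # arguments; *args would hide a mismatched call instead of failing loudly.
    def _set_oob_score(self, X, y):
        raise NotImplementedError("OOB score not supported by tree embedding")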

sklearn/ensemble/tests/test_forest.py

Lines changed: 116 additions & 39 deletions

@@ -20,6 +20,8 @@
 from sklearn.utils.testing import assert_false, assert_true
 from sklearn.utils.testing import assert_less, assert_greater
 from sklearn.utils.testing import assert_greater_equal
+from sklearn.utils.testing import assert_raises
+from sklearn.utils.testing import assert_warns
 
 from sklearn import datasets
 from sklearn.decomposition import TruncatedSVD
@@ -221,8 +223,9 @@ def test_importances():
 def check_oob_score(name, X, y, n_estimators=20):
     """Check that oob prediction is a good estimation of the generalization
     error."""
+    # Proper behavior
     est = FOREST_ESTIMATORS[name](oob_score=True, random_state=0,
-                                  n_estimators=n_estimators)
+                                  n_estimators=n_estimators, bootstrap=True)
     n_samples = X.shape[0]
     est.fit(X[:n_samples // 2, :], y[:n_samples // 2])
     test_score = est.score(X[n_samples // 2:, :], y[n_samples // 2:])
@@ -233,15 +236,50 @@ def check_oob_score(name, X, y, n_estimators=20):
     assert_greater(test_score, est.oob_score_)
     assert_greater(est.oob_score_, .8)
 
+    # Check warning if not enough estimators
+    with np.errstate(divide="ignore", invalid="ignore"):
+        est = FOREST_ESTIMATORS[name](oob_score=True, random_state=0,
+                                      n_estimators=1, bootstrap=True)
+        assert_warns(UserWarning, est.fit, X, y)
+
 
 def test_oob_score():
-    yield check_oob_score, "RandomForestClassifier", iris.data, iris.target
-    yield (check_oob_score, "RandomForestRegressor", boston.data,
-           boston.target, 50)
+    for name in FOREST_CLASSIFIERS:
+        yield check_oob_score, name, iris.data, iris.target
+
+        # non-contiguous targets in classification
+        yield check_oob_score, name, iris.data, iris.target * 2 + 1
+
+    for name in FOREST_REGRESSORS:
+        yield check_oob_score, name, boston.data, boston.target, 50
+
 
-    # non-contiguous targets in classification
-    yield (check_oob_score, "RandomForestClassifier", iris.data,
-           iris.target * 2 + 1)
+def check_oob_score_raise_error(name):
+    ForestEstimator = FOREST_ESTIMATORS[name]
+
+    if name in FOREST_TRANSFORMERS:
+        for oob_score in [True, False]:
+            assert_raises(TypeError, ForestEstimator, oob_score=oob_score)
+
+        assert_raises(NotImplementedError, ForestEstimator()._set_oob_score,
+                      X, y)
+
+    else:
+        # Unfitted / no bootstrap / no oob_score
+        for oob_score, bootstrap in [(True, False), (False, True),
+                                     (False, False)]:
+            est = ForestEstimator(oob_score=oob_score, bootstrap=bootstrap,
+                                  random_state=0)
+            assert_false(hasattr(est, "oob_score_"))
+
+        # No bootstrap
+        assert_raises(ValueError, ForestEstimator(oob_score=True,
+                                                  bootstrap=False).fit, X, y)
+
+
+def test_oob_score_raise_error():
+    for name in FOREST_ESTIMATORS:
+        yield check_oob_score_raise_error, name
 
 
 def check_gridsearch(name):
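The error path exercised by check_oob_score_raise_error is also visible through the public API; a hedged usage sketch follows (estimator and dataset choices are illustrative, not part of this diff):

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
X, y = iris.data, iris.target

# oob_score=True without bootstrap sampling is rejected at fit time.
try:
    RandomForestClassifier(oob_score=True, bootstrap=False).fit(X, y)
except ValueError as exc:
    print("expected failure:", exc)

# With bootstrap=True (the default), the fitted forest exposes oob_score_,
# an out-of-bag estimate of the generalization accuracy.
forest = RandomForestClassifier(n_estimators=20, oob_score=True,
                                random_state=0).fit(X, y)
print("OOB accuracy estimate:", forest.oob_score_)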
@@ -302,8 +340,17 @@ def test_pickle():
     yield check_pickle, name, boston.data[::2], boston.target[::2]
 
 
-def check_multioutput(name, X_train, X_test, y_train, y_test):
+def check_multioutput(name):
     """Check estimators on multi-output problems."""
+
+    X_train = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-2, 1],
+               [-1, 1], [-1, 2], [2, -1], [1, -1], [1, -2]]
+
+    y_train = [[-1, 0], [-1, 0], [-1, 0], [1, 1], [1, 1], [1, 1], [-1, 2],
+               [-1, 2], [-1, 2], [1, 3], [1, 3], [1, 3]]
+    X_test = [[-1, -1], [1, 1], [-1, 1], [1, -1]]
+    y_test = [[-1, 0], [1, 1], [-1, 2], [1, 3]]
+
     est = FOREST_ESTIMATORS[name](random_state=0, bootstrap=False)
     y_pred = est.fit(X_train, y_train).predict(X_test)
     assert_array_almost_equal(y_pred, y_test)
@@ -322,40 +369,11 @@ def check_multioutput(name, X_train, X_test, y_train, y_test):
 
 
 def test_multioutput():
-    X_train = [[-2, -1],
-               [-1, -1],
-               [-1, -2],
-               [1, 1],
-               [1, 2],
-               [2, 1],
-               [-2, 1],
-               [-1, 1],
-               [-1, 2],
-               [2, -1],
-               [1, -1],
-               [1, -2]]
-
-    y_train = [[-1, 0],
-               [-1, 0],
-               [-1, 0],
-               [1, 1],
-               [1, 1],
-               [1, 1],
-               [-1, 2],
-               [-1, 2],
-               [-1, 2],
-               [1, 3],
-               [1, 3],
-               [1, 3]]
-
-    X_test = [[-1, -1], [1, 1], [-1, 1], [1, -1]]
-    y_test = [[-1, 0], [1, 1], [-1, 2], [1, 3]]
-
     for name in FOREST_CLASSIFIERS:
-        yield check_multioutput, name, X_train, X_test, y_train, y_test
+        yield check_multioutput, name
 
     for name in FOREST_REGRESSORS:
-        yield check_multioutput, name, X_train, X_test, y_train, y_test
+        yield check_multioutput, name
 
 
 def check_classes_shape(name):
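For readers unfamiliar with the multi-output setup being tested, here is a short sketch using the same toy data as check_multioutput (the estimator choice is illustrative):

from sklearn.ensemble import RandomForestClassifier

# Toy multi-output data from check_multioutput: two target columns per sample.
X_train = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-2, 1],
           [-1, 1], [-1, 2], [2, -1], [1, -1], [1, -2]]
y_train = [[-1, 0], [-1, 0], [-1, 0], [1, 1], [1, 1], [1, 1], [-1, 2],
           [-1, 2], [-1, 2], [1, 3], [1, 3], [1, 3]]
X_test = [[-1, -1], [1, 1], [-1, 1], [1, -1]]

clf = RandomForestClassifier(random_state=0, bootstrap=False)
y_pred = clf.fit(X_train, y_train).predict(X_test)
print(y_pred)                     # one row of two predicted outputs per sample
print(clf.predict_proba(X_test))  # a list with one probability array per output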
@@ -593,6 +611,65 @@ def test_min_weight_fraction_leaf():
         yield check_min_weight_fraction_leaf, name, X, y
 
 
+def check_memory_layout(name, dtype):
+    """Check that it works no matter the memory layout"""
+
+    est = FOREST_ESTIMATORS[name](random_state=0, bootstrap=False)
+
+    # Nothing
+    X = np.asarray(iris.data, dtype=dtype)
+    y = iris.target
+    assert_array_equal(est.fit(X, y).predict(X), y)
+
+    # C-order
+    X = np.asarray(iris.data, order="C", dtype=dtype)
+    y = iris.target
+    assert_array_equal(est.fit(X, y).predict(X), y)
+
+    # F-order
+    X = np.asarray(iris.data, order="F", dtype=dtype)
+    y = iris.target
+    assert_array_equal(est.fit(X, y).predict(X), y)
+
+    # Contiguous
+    X = np.ascontiguousarray(iris.data, dtype=dtype)
+    y = iris.target
+    assert_array_equal(est.fit(X, y).predict(X), y)
+
+    # Strided
+    X = np.asarray(iris.data[::3], dtype=dtype)
+    y = iris.target[::3]
+    assert_array_equal(est.fit(X, y).predict(X), y)
+
+
+def test_memory_layout():
+    for name, dtype in product(FOREST_CLASSIFIERS, [np.float64, np.float32]):
+        yield check_memory_layout, name, dtype
+
+    for name, dtype in product(FOREST_REGRESSORS, [np.float64, np.float32]):
+        yield check_memory_layout, name, dtype
+
+
+def check_1d_input(name, X, X_2d, y):
+    ForestEstimator = FOREST_ESTIMATORS[name]
+    assert_raises(ValueError, ForestEstimator(random_state=0).fit, X, y)
+
+    est = ForestEstimator(random_state=0)
+    est.fit(X_2d, y)
+
+    if name in FOREST_CLASSIFIERS or name in FOREST_REGRESSORS:
+        assert_raises(ValueError, est.predict, X)
+
+
+def test_1d_input():
+    X = iris.data[:, 0].ravel()
+    X_2d = iris.data[:, 0].reshape((-1, 1))
+    y = iris.target
+
+    for name in FOREST_ESTIMATORS:
+        yield check_1d_input, name, X, X_2d, y
+
+
 if __name__ == "__main__":
     import nose
     nose.runmodule()
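The memory-layout and 1-d input checks added above correspond to public-API behaviour; a brief sketch on the same iris data (illustrative only, not part of this diff):

import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
y = iris.target

# Fortran-ordered float32 data and a strided view are both accepted by fit.
X_f = np.asarray(iris.data, order="F", dtype=np.float32)
X_strided = iris.data[::3]

clf = RandomForestClassifier(random_state=0)
clf.fit(X_f, y)
clf.fit(X_strided, y[::3])

# A 1-d feature array, by contrast, is rejected with a ValueError.
try:
    clf.fit(iris.data[:, 0].ravel(), y)
except ValueError as exc:
    print("expected failure:", exc)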
