2929from sklearn .utils .testing import assert_true
3030from sklearn .utils .testing import raises
3131from sklearn .utils .validation import check_random_state
32+ from sklearn .utils .validation import NotFittedError
3233
3334from sklearn .tree import DecisionTreeClassifier
3435from sklearn .tree import DecisionTreeRegressor
@@ -494,7 +495,7 @@ def test_error():
494495 for name , TreeEstimator in CLF_TREES .items ():
495496 # predict before fit
496497 est = TreeEstimator ()
497- assert_raises (Exception , est .predict_proba , X )
498+ assert_raises (NotFittedError , est .predict_proba , X )
498499
499500 est .fit (X , y )
500501 X2 = [- 2 , - 1 , 1 ] # wrong feature shape for sample
@@ -527,7 +528,7 @@ def test_error():
527528
528529 # predict before fitting
529530 est = TreeEstimator ()
530- assert_raises (Exception , est .predict , T )
531+ assert_raises (NotFittedError , est .predict , T )
531532
532533 # predict on vector with different dims
533534 est .fit (X , y )
@@ -545,6 +546,10 @@ def test_error():
545546 clf .fit (X , y )
546547 assert_raises (ValueError , clf .predict , Xt )
547548
549+ # apply before fitting
550+ est = TreeEstimator ()
551+ assert_raises (NotFittedError , est .apply , T )
552+
548553
549554def test_min_samples_leaf ():
550555 # Test if leaves contain more than leaf_count training examples
@@ -1208,6 +1213,8 @@ def check_explicit_sparse_zeros(tree, max_depth=3,
12081213 Xs = (X_test , X_sparse_test )
12091214 for X1 , X2 in product (Xs , Xs ):
12101215 assert_array_almost_equal (s .tree_ .apply (X1 ), d .tree_ .apply (X2 ))
1216+ assert_array_almost_equal (s .apply (X1 ), d .apply (X2 ))
1217+ assert_array_almost_equal (s .apply (X1 ), s .tree_ .apply (X1 ))
12111218 assert_array_almost_equal (s .predict (X1 ), d .predict (X2 ))
12121219
12131220 if tree in CLF_TREES :
@@ -1266,3 +1273,29 @@ def check_min_weight_leaf_split_level(name):
12661273def test_min_weight_leaf_split_level ():
12671274 for name in ALL_TREES :
12681275 yield check_min_weight_leaf_split_level , name
1276+
1277+
1278+ def check_public_apply (name ):
1279+ X_small32 = X_small .astype (tree ._tree .DTYPE )
1280+
1281+ est = ALL_TREES [name ]()
1282+ est .fit (X_small , y_small )
1283+ assert_array_equal (est .apply (X_small ),
1284+ est .tree_ .apply (X_small32 ))
1285+
1286+
1287+ def check_public_apply_sparse (name ):
1288+ X_small32 = csr_matrix (X_small .astype (tree ._tree .DTYPE ))
1289+
1290+ est = ALL_TREES [name ]()
1291+ est .fit (X_small , y_small )
1292+ assert_array_equal (est .apply (X_small ),
1293+ est .tree_ .apply (X_small32 ))
1294+
1295+
1296+ def test_public_apply ():
1297+ for name in ALL_TREES :
1298+ yield (check_public_apply , name )
1299+
1300+ for name in SPARSE_TREES :
1301+ yield (check_public_apply_sparse , name )
0 commit comments