@@ -371,7 +371,7 @@ def prune(self, n_leaves):
371371
372372 """
373373
374- to_remove_count = self .node_count - len (self .leaves ) - n_leaves
374+ to_remove_count = self .node_count - len (self .leaves ) - n_leaves + 1
375375 nodes_to_remove = self .pruning_order (to_remove_count )
376376
377377 out_tree = self ._copy ()
@@ -1130,7 +1130,8 @@ def __init__(self, criterion="mse",
11301130 self .find_split_ = _tree ._find_best_random_split
11311131
11321132
1133- def cv_scores_vs_n_leaves (clf , X , y , max_n_leaves = 10 , cv = 10 ):
1133+ def cv_scores_vs_n_leaves (clf , X , y , max_n_leaves = 10 , n_iterations = 10 ,
1134+ test_size = 0.1 , random_state = None ):
11341135 """Cross validation of scores for different values of the decision tree.
11351136
11361137 This function allows to test what the optimal size of the decision tree
@@ -1151,8 +1152,16 @@ def cv_scores_vs_n_leaves(clf, X, y, max_n_leaves=10, cv=10):
11511152 max_n_leaves : int, optional (default=10)
11521153 maximum number of leaves of the tree to prune
11531154
1154- cv : int, optional (default=10)
1155- Size of the KFold cross validation generator
1155+ n_iterations : int, optional (default=10)
1156+ Number of re-shuffling & splitting iterations.
1157+
1158+ test_size : float or int, optional (default=0.1)
1159+ If float, should be between 0.0 and 1.0 and represent the
1160+ proportion of the dataset to include in the test split. If
1161+ int, represents the absolute number of test samples.
1162+
1163+ random_state : int or RandomState, optional (default=None)
1164+ Pseudo-random number generator state used for random sampling.
11561165
11571166 Returns
11581167 -------
@@ -1164,11 +1173,12 @@ def cv_scores_vs_n_leaves(clf, X, y, max_n_leaves=10, cv=10):
11641173 """
11651174
11661175 from ..base import clone
1167- from ..cross_validation import KFold
1176+ from ..cross_validation import ShuffleSplit
11681177
11691178 scores = list ()
11701179
1171- kf = KFold (len (y ), cv )
1180+ kf = ShuffleSplit (len (y ), n_iterations , test_size ,
1181+ random_state = random_state )
11721182 for train , test in kf :
11731183 estimator = clone (clf )
11741184 fitted = estimator .fit (X [train ], y [train ])
0 commit comments