Skip to content

Commit 45adf40

Browse files
author
Steve Genoud
committed
Use shuffle and split to cross validate
1 parent 89af373 commit 45adf40

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

sklearn/tree/tree.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def prune(self, n_leaves):
371371
372372
"""
373373

374-
to_remove_count = self.node_count - len(self.leaves) - n_leaves
374+
to_remove_count = self.node_count - len(self.leaves) - n_leaves + 1
375375
nodes_to_remove = self.pruning_order(to_remove_count)
376376

377377
out_tree = self._copy()
@@ -1130,7 +1130,8 @@ def __init__(self, criterion="mse",
11301130
self.find_split_ = _tree._find_best_random_split
11311131

11321132

1133-
def cv_scores_vs_n_leaves(clf, X, y, max_n_leaves=10, cv=10):
1133+
def cv_scores_vs_n_leaves(clf, X, y, max_n_leaves=10, n_iterations=10,
1134+
test_size=0.1, random_state=None):
11341135
"""Cross validation of scores for different values of the decision tree.
11351136
11361137
This function allows to test what the optimal size of the decision tree
@@ -1151,8 +1152,16 @@ def cv_scores_vs_n_leaves(clf, X, y, max_n_leaves=10, cv=10):
11511152
max_n_leaves : int, optional (default=10)
11521153
maximum number of leaves of the tree to prune
11531154
1154-
cv : int, optional (default=10)
1155-
Size of the KFold cross validation generator
1155+
n_iterations : int, optional (default=10)
1156+
Number of re-shuffling & splitting iterations.
1157+
1158+
test_size : float (default=0.1) or int
1159+
If float, should be between 0.0 and 1.0 and represent the
1160+
proportion of the dataset to include in the test split. If
1161+
int, represents the absolute number of test samples.
1162+
1163+
random_state : int or RandomState
1164+
Pseudo-random number generator state used for random sampling.
11561165
11571166
Returns
11581167
-------
@@ -1164,11 +1173,12 @@ def cv_scores_vs_n_leaves(clf, X, y, max_n_leaves=10, cv=10):
11641173
"""
11651174

11661175
from ..base import clone
1167-
from ..cross_validation import KFold
1176+
from ..cross_validation import ShuffleSplit
11681177

11691178
scores = list()
11701179

1171-
kf = KFold(len(y), cv)
1180+
kf = ShuffleSplit(len(y), n_iterations, test_size,
1181+
random_state=random_state)
11721182
for train, test in kf:
11731183
estimator = clone(clf)
11741184
fitted = estimator.fit(X[train], y[train])

0 commit comments

Comments
 (0)