diff --git a/doc/modules/tree.rst b/doc/modules/tree.rst
index 62d37d5a640f5..6ce4087b7ef49 100644
--- a/doc/modules/tree.rst
+++ b/doc/modules/tree.rst
@@ -55,10 +55,10 @@ Some advantages of decision trees are:
 The disadvantages of decision trees include:
 
   - Decision-tree learners can create over-complex trees that do not
-    generalise the data well. This is called overfitting. Mechanisms
-    such as pruning (not currently supported), setting the minimum
-    number of samples required at a leaf node or setting the maximum
-    depth of the tree are necessary to avoid this problem.
+    generalise the data well. This is called overfitting. Mechanisms such as
+    pruning, setting the minimum number of samples required at a leaf node or
+    setting the maximum depth of the tree are necessary to avoid this
+    problem.
 
   - Decision trees can be unstable because small variations in the data
     might result in a completely different tree being generated.
@@ -183,6 +183,75 @@ instead of integer values::
 
   * :ref:`example_tree_plot_tree_regression.py`
 
+
+.. _tree_pruning:
+
+Pruning
+=======
+
+A common approach to get the best possible tree is to grow a very large tree
+(for instance with ``max_depth=8``) and then prune it back to an optimal size.
+In addition to the ``prune`` method provided by both
+:class:`DecisionTreeRegressor` and :class:`DecisionTreeClassifier`, the
+function ``prune_path`` is useful to find the optimal size for a tree.
+
+The ``prune`` method takes a single argument, the number of leaves (an int)
+that the fitted tree should have::
+
+    >>> from sklearn.datasets import load_boston
+    >>> from sklearn import tree
+    >>> boston = load_boston()
+    >>> clf = tree.DecisionTreeRegressor(max_depth=8)
+    >>> clf = clf.fit(boston.data, boston.target)
+    >>> clf = clf.prune(8)
+
+In order to find the optimal number of leaves we can use cross validated
+scores on the data::
+
+    >>> from sklearn.datasets import load_boston
+    >>> from sklearn import tree
+    >>> boston = load_boston()
+    >>> clf = tree.DecisionTreeRegressor(max_depth=8)
+    >>> scores = tree.prune_path(clf, boston.data, boston.target,
+    ...                          max_n_leaves=20, n_iterations=10, random_state=0)
+
+The scores can then be plotted with the following function::
+
+    def plot_pruned_path(scores, with_std=True):
+        """Plots the cross validated scores versus the number of leaves of trees"""
+        import numpy as np
+        import matplotlib.pyplot as plt
+        means = np.array([np.mean(s) for s in scores])
+        stds = np.array([np.std(s) for s in scores]) / np.sqrt(len(scores[1]))
+
+        x = range(len(scores) + 1, 1, -1)
+
+        plt.plot(x, means)
+        if with_std:
+            plt.plot(x, means + 2 * stds, lw=1, c='0.7')
+            plt.plot(x, means - 2 * stds, lw=1, c='0.7')
+
+        plt.xlabel('Number of leaves')
+        plt.ylabel('Cross validated score')
+
+
+For instance, on the Boston dataset we obtain the following graph:
+
+.. figure:: ../auto_examples/tree/images/plot_prune_boston_1.png
+   :target: ../auto_examples/tree/plot_prune_boston.html
+   :align: center
+   :scale: 75
+
+Here we clearly see that the optimal number of leaves is between 6 and 9.
+Beyond that, additional leaves neither improve nor diminish the cross
+validated score.
+
+.. topic:: Examples:
+
+ * :ref:`example_tree_plot_prune_boston.py`
+ * :ref:`example_tree_plot_overfitting_cv.py`
+
+
 .. _tree_multioutput:
 
 Multi-output problems
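The section above only demonstrates ``prune`` on :class:`DecisionTreeRegressor`.
A minimal sketch of the same call on a classifier, assuming the patch below is
applied (the iris dataset and the leaf count of 5 are illustrative choices, not
part of the patch)::

    >>> from sklearn.datasets import load_iris
    >>> from sklearn import tree
    >>> iris = load_iris()
    >>> clf = tree.DecisionTreeClassifier(max_depth=8)
    >>> clf = clf.fit(iris.data, iris.target)
    >>> clf = clf.prune(5)   # prune the fitted tree back to 5 leaves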
diff --git a/examples/tree/plot_overfitting_cv.py b/examples/tree/plot_overfitting_cv.py
new file mode 100644
index 0000000000000..3669d142b4600
--- /dev/null
+++ b/examples/tree/plot_overfitting_cv.py
@@ -0,0 +1,72 @@
+"""
+====================================================
+Comparison of cross validated score with overfitting
+====================================================
+
+These two plots compare the cross validated scores of the regression of a
+simple function. We see that below 7 leaves, where the cross validated score
+is maximal, the regression stays far from the real function. On the other
+hand, for higher numbers of leaves we clearly overfit.
+
+"""
+print __doc__
+
+import numpy as np
+from sklearn import tree
+
+
+def plot_pruned_path(scores, with_std=True):
+    """Plots the cross validated scores versus the number of leaves of trees"""
+    import matplotlib.pyplot as plt
+    means = np.array([np.mean(s) for s in scores])
+    stds = np.array([np.std(s) for s in scores]) / np.sqrt(len(scores[1]))
+
+    x = range(len(scores) + 1, 1, -1)
+
+    plt.plot(x, means)
+    if with_std:
+        plt.plot(x, means + 2 * stds, lw=1, c='0.7')
+        plt.plot(x, means - 2 * stds, lw=1, c='0.7')
+
+    plt.xlabel('Number of leaves')
+    plt.ylabel('Cross validated score')
+
+
+# Create a random dataset
+rng = np.random.RandomState(1)
+X = np.sort(5 * rng.rand(80, 1), axis=0)
+y = np.sin(X).ravel()
+y[1::5] += 3 * (0.5 - rng.rand(16))
+
+
+clf = tree.DecisionTreeRegressor(max_depth=20)
+scores = tree.prune_path(clf, X, y, max_n_leaves=20,
+                         n_iterations=100, random_state=0)
+plot_pruned_path(scores)
+
+clf = tree.DecisionTreeRegressor(max_depth=20, n_leaves=15)
+clf.fit(X, y)
+X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
+
+# Prepare the different pruning levels
+y_15 = clf.predict(X_test)
+
+clf = clf.prune(7)
+y_7 = clf.predict(X_test)
+
+clf = clf.prune(2)
+y_2 = clf.predict(X_test)
+
+# Plot the results
+import pylab as pl
+
+pl.figure()
+pl.scatter(X, y, c="k", label="data")
+pl.plot(X_test, y_2, c="g", label="n_leaves=2", linewidth=2)
+pl.plot(X_test, y_7, c="b", label="n_leaves=7", linewidth=2)
+pl.plot(X_test, y_15, c="r", label="n_leaves=15", linewidth=2)
+pl.xlabel("data")
+pl.ylabel("target")
+pl.title("Decision Tree Regression with levels of pruning")
+pl.legend()
+pl.show()
diff --git a/examples/tree/plot_prune_boston.py b/examples/tree/plot_prune_boston.py
new file mode 100644
index 0000000000000..3d61a126bee15
--- /dev/null
+++ b/examples/tree/plot_prune_boston.py
@@ -0,0 +1,39 @@
+"""
+============================================
+Cross validated scores of the Boston dataset
+============================================
+
+"""
+print __doc__
+
+import numpy as np
+from sklearn.datasets import load_boston
+from sklearn import tree
+
+
+def plot_pruned_path(scores, with_std=True):
+    """Plots the cross validated scores versus the number of leaves of trees"""
+    import matplotlib.pyplot as plt
+    means = np.array([np.mean(s) for s in scores])
+    stds = np.array([np.std(s) for s in scores]) / np.sqrt(len(scores[1]))
+
+    x = range(len(scores) + 1, 1, -1)
+
+    plt.plot(x, means)
+    if with_std:
+        plt.plot(x, means + 2 * stds, lw=1, c='0.7')
+        plt.plot(x, means - 2 * stds, lw=1, c='0.7')
+
+    plt.xlabel('Number of leaves')
+    plt.ylabel('Cross validated score')
+
+
+boston = load_boston()
+clf = tree.DecisionTreeRegressor(max_depth=8)
+
+# Compute the cross validated scores
+scores = tree.prune_path(clf, boston.data, boston.target,
+                         max_n_leaves=20, n_iterations=10,
+                         random_state=0)
+
+plot_pruned_path(scores)
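Both example scripts above read the best tree size off the plot by eye. A small
sketch of how the output of ``prune_path`` could instead be turned into a
concrete leaf count, assuming (as documented for ``prune_path`` further down in
this patch) that ``scores[0]`` corresponds to ``max_n_leaves`` leaves and
``scores[-1]`` to two leaves::

    import numpy as np

    # continuing plot_prune_boston.py, where max_n_leaves=20
    means = np.array([np.mean(s) for s in scores])
    best_n_leaves = 20 - int(np.argmax(means))   # scores[i] -> (20 - i) leaves

    clf = clf.fit(boston.data, boston.target).prune(best_n_leaves)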
diff --git a/sklearn/tree/__init__.py b/sklearn/tree/__init__.py
index 4becdfd010cd7..e412eb1ea33c4 100644
--- a/sklearn/tree/__init__.py
+++ b/sklearn/tree/__init__.py
@@ -8,3 +8,4 @@
 from .tree import ExtraTreeClassifier
 from .tree import ExtraTreeRegressor
 from .tree import export_graphviz
+from .tree import prune_path
diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py
index 9dc7892c3e46d..7754cc66461d8 100644
--- a/sklearn/tree/tree.py
+++ b/sklearn/tree/tree.py
@@ -174,7 +174,7 @@ class Tree(object):
     LEAF = -1
     UNDEFINED = -2
 
-    def __init__(self, n_classes, n_features, n_outputs, capacity=3):
+    def __init__(self, n_classes, n_features, n_outputs=1, capacity=3):
         self.n_classes = n_classes
         self.n_features = n_features
         self.n_outputs = n_outputs
@@ -266,6 +266,122 @@ def _add_leaf(self, parent, is_left_child, value, error, n_samples):
 
         return node_id
 
+    def _copy(self):
+        """Return a new Tree object with copies of this tree's arrays"""
+        new_tree = Tree(self.n_classes, self.n_features, self.n_outputs)
+        new_tree.node_count = self.node_count
+        new_tree.children = self.children.copy()
+        new_tree.feature = self.feature.copy()
+        new_tree.threshold = self.threshold.copy()
+        new_tree.value = self.value.copy()
+        new_tree.best_error = self.best_error.copy()
+        new_tree.init_error = self.init_error.copy()
+        new_tree.n_samples = self.n_samples.copy()
+
+        return new_tree
+
+    @staticmethod
+    def _get_leaves(children):
+        """Lists the leaves from the children array of a tree object"""
+        return np.where(np.all(children == Tree.LEAF, axis=1))[0]
+
+    @property
+    def leaves(self):
+        """Indices of the leaf nodes of the tree"""
+        return self._get_leaves(self.children)
+
+    def pruning_order(self, max_to_prune=None):
+        """Compute the order in which the nodes should be pruned.
+
+        The algorithm used is weakest link pruning: it removes first the
+        nodes that improve the tree the least.
+
+        Parameters
+        ----------
+        max_to_prune : int, optional (default=all the nodes)
+            Maximum number of nodes to prune.
+
+        Returns
+        -------
+        nodes : numpy array
+            List of the nodes to remove to get to the optimal subtree.
+
+        References
+        ----------
+
+        .. [1] T. Hastie, R. Tibshirani and J. Friedman, "The Elements of
+               Statistical Learning", 2001, section 9.2.1
+
+        """
+
+        def _get_terminal_nodes(children):
+            """Lists the nodes that only have leaves as children"""
+            leaves = self._get_leaves(children)
+            child_is_leaf = np.in1d(children, leaves).reshape(children.shape)
+            return np.where(np.all(child_is_leaf, axis=1))[0]
+
+        def _next_to_prune(tree, children=None):
+            """Weakest link pruning for the subtree defined by children"""
+
+            if children is None:
+                children = tree.children
+
+            t_nodes = _get_terminal_nodes(children)
+            # g_i is the improvement brought by the split at each terminal
+            # node; the node with the smallest improvement is pruned first
+            g_i = tree.init_error[t_nodes] - tree.best_error[t_nodes]
+
+            return t_nodes[np.argmin(g_i)]
+
+        if max_to_prune is None:
+            max_to_prune = self.node_count
+
+        children = self.children.copy()
+        nodes = list()
+
+        while True:
+            node = _next_to_prune(self, children)
+            nodes.append(node)
+
+            if (len(nodes) == max_to_prune) or (node == 0):
+                return np.array(nodes)
+
+            # Remove the subtree from the children array
+            children[children[node], :] = Tree.UNDEFINED
+            children[node, :] = Tree.LEAF
+
+    def prune(self, n_leaves):
+        """Prune the tree to obtain the optimal subtree with n_leaves leaves.
+
+        Parameters
+        ----------
+        n_leaves : int
+            The number of leaves the pruned tree should have.
+
+        Returns
+        -------
+        tree : a Tree object
+            A new, pruned tree.
+
+        References
+        ----------
+
Hastie, "The elements of statistical + learning", 2001, section 9.2.1 + + """ + + to_remove_count = self.node_count - len(self.leaves) - n_leaves + 1 + nodes_to_remove = self.pruning_order(to_remove_count) + + out_tree = self._copy() + + for node in nodes_to_remove: + #TODO: Add a Tree method to remove a branch of a tree + out_tree.children[out_tree.children[node], :] = Tree.UNDEFINED + out_tree.children[node, :] = Tree.LEAF + out_tree.node_count -= 2 + + return out_tree + def build(self, X, y, criterion, max_depth, min_samples_split, min_samples_leaf, min_density, max_features, random_state, find_split, sample_mask=None, X_argsorted=None): @@ -440,6 +556,7 @@ def __init__(self, criterion, max_depth, min_samples_split, min_samples_leaf, + n_leaves, min_density, max_features, compute_importances, @@ -448,6 +565,7 @@ def __init__(self, criterion, self.max_depth = max_depth self.min_samples_split = min_samples_split self.min_samples_leaf = min_samples_leaf + self.n_leaves = n_leaves self.min_density = min_density self.max_features = max_features self.compute_importances = compute_importances @@ -462,6 +580,22 @@ def __init__(self, criterion, self.tree_ = None self.feature_importances_ = None + def prune(self, n_leaves): + """Prunes the decision tree + + This method is necessary to avoid overfitting tree models. While broad + decision trees should be computed in the first place, pruning them + allows for smaller trees. + + Parameters + ---------- + n_leaves : int + the number of leaves of the pruned tree + + """ + self.tree_ = self.tree_.prune(n_leaves) + return self + def fit(self, X, y, sample_mask=None, X_argsorted=None): """Build a decision tree from the training set (X, y). @@ -565,6 +699,9 @@ def fit(self, X, y, sample_mask=None, X_argsorted=None): self.feature_importances_ = \ self.tree_.compute_feature_importances() + if self.n_leaves is not None: + self.prune(self.n_leaves) + return self def predict(self, X): @@ -634,6 +771,10 @@ class DecisionTreeClassifier(BaseDecisionTree, ClassifierMixin): min_samples_leaf : integer, optional (default=1) The minimum number of samples required to be at a leaf node. + n_leaves : integer, optional (default=None) + The number of leaves of the post-pruned tree. If None, no post-pruning + will be run. + min_density : float, optional (default=0.1) This parameter controls a trade-off in an optimization heuristic. It controls the minimum density of the `sample_mask` (i.e. the @@ -711,6 +852,7 @@ def __init__(self, criterion="gini", max_depth=None, min_samples_split=1, min_samples_leaf=1, + n_leaves=None, min_density=0.1, max_features=None, compute_importances=False, @@ -719,6 +861,7 @@ def __init__(self, criterion="gini", max_depth, min_samples_split, min_samples_leaf, + n_leaves, min_density, max_features, compute_importances, @@ -814,6 +957,10 @@ class DecisionTreeRegressor(BaseDecisionTree, RegressorMixin): min_samples_leaf : integer, optional (default=1) The minimum number of samples required to be at a leaf node. + n_leaves : integer, optional (default=None) + The number of leaves of the post-pruned tree. If None, no post-pruning + will be run. + min_density : float, optional (default=0.1) This parameter controls a trade-off in an optimization heuristic. It controls the minimum density of the `sample_mask` (i.e. 
@@ -893,6 +1040,7 @@ def __init__(self, criterion="mse",
                        max_depth=None,
                        min_samples_split=1,
                        min_samples_leaf=1,
+                       n_leaves=None,
                        min_density=0.1,
                        max_features=None,
                        compute_importances=False,
@@ -901,6 +1049,7 @@ def __init__(self, criterion="mse",
                                       max_depth,
                                       min_samples_split,
                                       min_samples_leaf,
+                                      n_leaves,
                                       min_density,
                                       max_features,
                                       compute_importances,
@@ -933,6 +1082,7 @@ def __init__(self, criterion="gini",
                        max_depth=None,
                        min_samples_split=1,
                        min_samples_leaf=1,
+                       n_leaves=None,
                        min_density=0.1,
                        max_features="auto",
                        compute_importances=False,
@@ -941,6 +1091,7 @@ def __init__(self, criterion="gini",
                                       max_depth,
                                       min_samples_split,
                                       min_samples_leaf,
+                                      n_leaves,
                                       min_density,
                                       max_features,
                                       compute_importances,
@@ -979,6 +1130,7 @@ def __init__(self, criterion="mse",
                        max_depth=None,
                        min_samples_split=1,
                        min_samples_leaf=1,
+                       n_leaves=None,
                        min_density=0.1,
                        max_features="auto",
                        compute_importances=False,
@@ -987,9 +1139,75 @@ def __init__(self, criterion="mse",
                                       max_depth,
                                       min_samples_split,
                                       min_samples_leaf,
+                                      n_leaves,
                                       min_density,
                                       max_features,
                                       compute_importances,
                                       random_state)
 
         self.find_split_ = _tree._find_best_random_split
+
+
+def prune_path(clf, X, y, max_n_leaves=10, n_iterations=10,
+               test_size=0.1, random_state=None):
+    """Cross validated scores of post-pruned decision trees of various sizes.
+
+    This function helps find the optimal size of a post-pruned decision tree
+    by computing cross validated scores for the different tree sizes.
+
+    Parameters
+    ----------
+    clf : decision tree estimator object
+        The object to use to fit the data.
+
+    X : array-like of shape at least 2D
+        The data to fit.
+
+    y : array-like
+        The target variable to try to predict.
+
+    max_n_leaves : int, optional (default=10)
+        Maximum number of leaves of the trees to score.
+
+    n_iterations : int, optional (default=10)
+        Number of re-shuffling & splitting iterations.
+
+    test_size : float (default=0.1) or int
+        If float, should be between 0.0 and 1.0 and represent the
+        proportion of the dataset to include in the test split. If
+        int, represents the absolute number of test samples.
+
+    random_state : int or RandomState
+        Pseudo-random number generator state used for random sampling.
+
+    Returns
+    -------
+    scores : list of list of floats
+        The cross validated scores, grouped by tree size: scores[0]
+        corresponds to trees with max_n_leaves leaves and scores[-1] to
+        trees with just two leaves.
+
+    """
+
+    from ..base import clone
+    from ..cross_validation import ShuffleSplit
+
+    scores = list()
+
+    kf = ShuffleSplit(len(y), n_iterations, test_size,
+                      random_state=random_state)
+    for train, test in kf:
+        estimator = clone(clf)
+        fitted = estimator.fit(X[train], y[train])
+
+        loc_scores = list()
+        for i in range(max_n_leaves, 1, -1):
+            # Loop from the larger to the smaller sizes so that the full tree
+            # is built only once and then progressively pruned
+            fitted.prune(n_leaves=i)
+            loc_scores.append(fitted.score(X[test], y[test]))
+
+        scores.append(loc_scores)
+
+    return zip(*scores)
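As a closing note, the new ``n_leaves`` constructor parameter added throughout
this patch makes post-pruning part of ``fit`` itself, so the grow-then-prune
step can be written in one go. A minimal usage sketch under that assumption
(dataset and parameter values are illustrative)::

    >>> from sklearn.datasets import load_boston
    >>> from sklearn import tree
    >>> boston = load_boston()
    >>> clf = tree.DecisionTreeRegressor(max_depth=8, n_leaves=6)
    >>> clf = clf.fit(boston.data, boston.target)  # grown, then pruned to 6 leaves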