"""
===================================================================
Multi-output Decision Tree Regression
===================================================================

Multi-output regression with :ref:`decision trees <tree>`: a single decision
tree is used to predict simultaneously the noisy x and y observations of a
circle given a single underlying feature. As a result, it learns a
piecewise-constant approximation of the circle.

We can see that if the maximum depth of the tree (controlled by the
`max_depth` parameter) is set too high, the decision tree learns too fine
details of the training data and learns from the noise, i.e. it overfits.
"""
# Function-call form prints the same text on Python 2 and also runs on Python 3.
print(__doc__)

import numpy as np
import pylab as pl

from sklearn.tree import DecisionTreeRegressor

# Create a random dataset.
# X is a single sorted feature in [-100, 100); y holds two targets per sample,
# (pi * sin(X), pi * cos(X)), i.e. points on a circle of radius pi.
rng = np.random.RandomState(1)
X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T

# Perturb every 5th sample with uniform noise in (-1.5, 1.5] on both targets.
# The noise shape is taken from the slice itself so it stays correct if the
# number of samples or the stride is changed.
y[::5, :] += 3 * (0.5 - rng.rand(*y[::5].shape))

# Fit regression models of increasing depth on the same data.
regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_3 = DecisionTreeRegressor(max_depth=10)
regr_1.fit(X, y)
regr_2.fit(X, y)
regr_3.fit(X, y)

# Predict both targets on a dense grid covering the training range.
X_test = np.arange(-100.0, 100.0, 0.01)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)
y_3 = regr_3.predict(X_test)

# Plot the results in target space: first output against second output.
pl.figure()
pl.scatter(y[:, 0], y[:, 1], c="k", label="data")
pl.scatter(y_1[:, 0], y_1[:, 1], c="g", label="max_depth=2", linewidth=2)
pl.scatter(y_2[:, 0], y_2[:, 1], c="r", label="max_depth=5", linewidth=2)
pl.scatter(y_3[:, 0], y_3[:, 1], c="b", label="max_depth=10", linewidth=2)
# Both axes show model outputs, not the input feature.
pl.xlabel("target 1")
pl.ylabel("target 2")
pl.title("Multi-output Decision Tree Regression")
pl.legend()
pl.show()