Skip to content

Commit 637ab82

Browse files
committed
added multi-output tree example
1 parent a08a910 commit 637ab82

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
"""
===================================================================
Multi-output Decision Tree Regression
===================================================================

Multi-output regression with :ref:`decision trees <tree>`: a decision tree
is used to predict simultaneously the noisy x and y observations of a circle
given a single underlying feature. As a result, it learns a piecewise
constant approximation of the circle.

We can see that if the maximum depth of the tree (controlled by the
`max_depth` parameter) is set too high, the decision trees learn too fine
details of the training data and learn from the noise, i.e. they overfit.
"""
# Parenthesised form is valid on both Python 2 and Python 3.
print(__doc__)

import numpy as np
import pylab as pl

from sklearn.tree import DecisionTreeRegressor

# Create a random dataset.
# X: 100 sorted samples of a single feature drawn uniformly from [-100, 100).
# y: two outputs per sample, (pi * sin(X), pi * cos(X)), i.e. points on a
#    circle of radius pi.
rng = np.random.RandomState(1)
X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T
# Perturb every 5th sample (20 of the 100 rows) on both outputs with uniform
# noise of magnitude up to 1.5.
y[::5, :] += 3 * (0.5 - rng.rand(20, 2))

# Fit regression models of increasing depth on the same two-output target.
clf_1 = DecisionTreeRegressor(max_depth=2)
clf_2 = DecisionTreeRegressor(max_depth=5)
clf_3 = DecisionTreeRegressor(max_depth=10)
clf_1.fit(X, y)
clf_2.fit(X, y)
clf_3.fit(X, y)

# Predict on a dense grid covering the training range of the feature.
X_test = np.arange(-100.0, 100.0, 0.01)[:, np.newaxis]
y_1 = clf_1.predict(X_test)
y_2 = clf_2.predict(X_test)
y_3 = clf_3.predict(X_test)

# Plot the results in output space: first target on the horizontal axis,
# second target on the vertical axis, for the data and for each model.
pl.figure()
pl.scatter(y[:, 0], y[:, 1], c="k", label="data")
pl.scatter(y_1[:, 0], y_1[:, 1], c="g", label="max_depth=2", linewidth=2)
pl.scatter(y_2[:, 0], y_2[:, 1], c="r", label="max_depth=5", linewidth=2)
pl.scatter(y_3[:, 0], y_3[:, 1], c="b", label="max_depth=10", linewidth=2)
# Both axes show targets (the input feature X is not plotted).
pl.xlabel("target 1")
pl.ylabel("target 2")
pl.title("Multi-output Decision Tree Regression")
pl.legend()
pl.show()

0 commit comments

Comments
 (0)