Skip to content

Commit aa853e7

Browse files
committed
树回归
1 parent 01bd8e2 commit aa853e7

File tree

1 file changed

+104
-22
lines changed

1 file changed

+104
-22
lines changed

Regression Trees/regTrees.py

Lines changed: 104 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#-*- coding:utf-8 -*-
2+
import matplotlib.pyplot as plt
23
import numpy as np
34

45
def loadDataSet(fileName):
@@ -21,6 +22,30 @@ def loadDataSet(fileName):
2122
dataMat.append(fltLine)
2223
return dataMat
2324

25+
def plotDataSet(filename):
    """
    Draw the samples of a data file as a scatter plot.

    Parameters:
        filename - name of the data file, passed straight to loadDataSet
    Returns:
        None; the figure is displayed with plt.show()
    Website:
        http://www.cuijiahua.com/
    Modify:
        2017-11-12
    """
    # Load the samples; each row is one sample.
    dataMat = loadDataSet(filename)
    # Column 0 is plotted on the x axis, column 1 on the y axis.
    xcord = [row[0] for row in dataMat]
    ycord = [row[1] for row in dataMat]
    fig = plt.figure()
    ax = fig.add_subplot(111)
    # Semi-transparent markers keep overlapping samples distinguishable.
    ax.scatter(xcord, ycord, s=20, c='blue', alpha=.5)
    plt.title('DataSet')
    plt.xlabel('X')
    plt.show()
48+
2449
def binSplitDataSet(dataSet, feature, value):
2550
"""
2651
函数说明:根据特征切分数据集合
@@ -145,34 +170,91 @@ def createTree(dataSet, leafType = regLeaf, errType = regErr, ops = (1, 4)):
145170
retTree = {}
146171
retTree['spInd'] = feat
147172
retTree['spVal'] = val
173+
#分成左数据集和右数据集
148174
lSet, rSet = binSplitDataSet(dataSet, feat, val)
175+
#创建左子树和右子树
149176
retTree['left'] = createTree(lSet, leafType, errType, ops)
150177
retTree['right'] = createTree(rSet, leafType, errType, ops)
151178
return retTree
152179

153-
# def linearSolve(dataSet): #helper function used in two places
154-
# m,n = shape(dataSet)
155-
# X = mat(ones((m,n))); Y = mat(ones((m,1)))#create a copy of data with 1 in 0th postion
156-
# X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]#and strip out Y
157-
# xTx = X.T*X
158-
# if linalg.det(xTx) == 0.0:
159-
# raise NameError('This matrix is singular, cannot do inverse,\n\
160-
# try increasing the second value of ops')
161-
# ws = xTx.I * (X.T * Y)
162-
# return ws,X,Y
180+
def isTree(obj):
    """
    Tell whether the given object is a tree node.

    Tree nodes are stored as dicts, so anything that is a dict (or a dict
    subclass) counts as a tree; every other value is treated as a leaf.

    Parameters:
        obj - object to test
    Returns:
        True if obj is a dict, False otherwise
    Website:
        http://www.cuijiahua.com/
    Modify:
        2017-12-14
    """
    return isinstance(obj, dict)
163194

164-
# def modelLeaf(dataSet):#create linear model and return coeficients
165-
# ws,X,Y = linearSolve(dataSet)
166-
# return ws
195+
def getMean(tree):
    """
    Collapse a tree into a single value: the mean of its two branches.

    A branch that is itself a tree is collapsed recursively before the
    average is taken. The tree passed in is not modified.

    Parameters:
        tree - tree node (dict) with 'left' and 'right' entries
    Returns:
        the average of the (collapsed) left and right branches
    Website:
        http://www.cuijiahua.com/
    Modify:
        2017-12-14
    """
    right = tree['right']
    left = tree['left']
    # Work on local copies of the branch values so the caller's tree is
    # left intact while sub-trees are reduced to their means.
    if isTree(right):
        right = getMean(right)
    if isTree(left):
        left = getMean(left)
    return (left + right) / 2.0
167210

168-
# def modelErr(dataSet):
169-
# ws,X,Y = linearSolve(dataSet)
170-
# yHat = X * ws
171-
# return sum(power(Y - yHat,2))
211+
def prune(tree, testData):
    """
    Post-prune a tree using a test set.

    Sub-trees are pruned recursively with the part of the test data that
    falls on their side of the split. Once both children of a node are
    leaves, the two leaves are merged into their mean when doing so gives
    a smaller squared error on the test data. The children of the tree
    passed in are replaced in place.

    Parameters:
        tree - tree node (dict) with 'spInd', 'spVal', 'left' and 'right'
        testData - test set; indexed as testData[:, -1], with the last
                   column compared against the leaf values
    Returns:
        the pruned tree, or a single value when the node was collapsed
        (empty test set) or its two leaves were merged
    Website:
        http://www.cuijiahua.com/
    Modify:
        2017-12-14
    """
    # No test sample reaches this node: collapse it to its mean.
    if np.shape(testData)[0] == 0:
        return getMean(tree)

    # A single split serves both the recursion and the error comparison:
    # its inputs (testData, spInd, spVal) do not change in between.
    lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])

    # Prune whichever children are still sub-trees.
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)

    # Merging is only considered when both children are leaves.
    if isTree(tree['left']) or isTree(tree['right']):
        return tree

    # Squared error on the test data when the two leaves are kept apart.
    errorNoMerge = (np.sum(np.power(lSet[:, -1] - tree['left'], 2))
                    + np.sum(np.power(rSet[:, -1] - tree['right'], 2)))
    # Squared error when both leaves are replaced by their mean.
    treeMean = (tree['left'] + tree['right']) / 2.0
    errorMerge = np.sum(np.power(testData[:, -1] - treeMean, 2))

    # Merge only on a strict improvement.
    if errorMerge < errorNoMerge:
        return treeMean
    return tree
172248

173249
if __name__ == '__main__':
    # Grow a regression tree from the training file and show it unpruned.
    print('剪枝前:')
    tree = createTree(np.mat(loadDataSet('ex2.txt')))
    print(tree)

    # Post-prune the same tree against the test file and show the result.
    print('\n剪枝后:')
    test_Mat = np.mat(loadDataSet('ex2test.txt'))
    print(prune(tree, test_Mat))

0 commit comments

Comments
 (0)