11#-*- coding:utf-8 -*-
2+ import matplotlib .pyplot as plt
23import numpy as np
34
45def loadDataSet (fileName ):
@@ -21,6 +22,30 @@ def loadDataSet(fileName):
2122 dataMat .append (fltLine )
2223 return dataMat
2324
def plotDataSet(filename):
    """
    Draw a scatter plot of the data set stored in a file.

    Parameters:
        filename - name of the data file, handed to loadDataSet
    Returns:
        None
    Website:
        http://www.cuijiahua.com/
    Modify:
        2017-11-12
    """
    samples = loadDataSet(filename)
    # Column 0 of every sample goes on the x axis, column 1 on the y axis.
    xcord = [sample[0] for sample in samples]
    ycord = [sample[1] for sample in samples]
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(xcord, ycord, s=20, c='blue', alpha=.5)
    plt.title('DataSet')
    plt.xlabel('X')
    plt.show()
48+
2449def binSplitDataSet (dataSet , feature , value ):
2550 """
2651 函数说明:根据特征切分数据集合
@@ -145,34 +170,91 @@ def createTree(dataSet, leafType = regLeaf, errType = regErr, ops = (1, 4)):
145170 retTree = {}
146171 retTree ['spInd' ] = feat
147172 retTree ['spVal' ] = val
173+ #分成左数据集和右数据集
148174 lSet , rSet = binSplitDataSet (dataSet , feat , val )
175+ #创建左子树和右子树
149176 retTree ['left' ] = createTree (lSet , leafType , errType , ops )
150177 retTree ['right' ] = createTree (rSet , leafType , errType , ops )
151178 return retTree
152179
153- # def linearSolve(dataSet): #helper function used in two places
154- # m,n = shape(dataSet)
155- # X = mat(ones((m,n))); Y = mat(ones((m,1)))#create a copy of data with 1 in 0th postion
156- # X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]#and strip out Y
157- # xTx = X.T*X
158- # if linalg.det(xTx) == 0.0:
159- # raise NameError('This matrix is singular, cannot do inverse,\n\
160- # try increasing the second value of ops')
161- # ws = xTx.I * (X.T * Y)
162- # return ws,X,Y
def isTree(obj):
    """
    Tell whether the given object is a tree node rather than a leaf value.

    Tree nodes in this file are dicts (see the 'spInd' / 'spVal' / 'left' /
    'right' keys written by createTree); anything else counts as a leaf.

    Parameters:
        obj - object to test
    Returns:
        True if obj is a dict (or a dict subclass), False otherwise
    Website:
        http://www.cuijiahua.com/
    Modify:
        2017-12-14
    """
    # isinstance replaces the string comparison on the type name and the
    # unused `import types`; it also recognises dict subclasses.
    return isinstance(obj, dict)
163194
164- # def modelLeaf(dataSet):#create linear model and return coeficients
165- # ws,X,Y = linearSolve(dataSet)
166- # return ws
def getMean(tree):
    """
    Collapse a tree into a single value: the mean of its two branches.

    A branch that is itself a tree is first collapsed recursively, so the
    result is the unweighted average of the left and right branch values.
    The tree passed in is left untouched; only the computed value is
    returned.

    Parameters:
        tree - tree node exposing 'left' and 'right' entries
    Returns:
        (left + right) / 2.0 after collapsing any sub-trees
    Website:
        http://www.cuijiahua.com/
    Modify:
        2017-12-14
    """
    # Work on local copies of the branch values so the caller's tree is not
    # rewritten as a side effect of computing the mean.
    right = tree['right']
    left = tree['left']
    if isTree(right):
        right = getMean(right)
    if isTree(left):
        left = getMean(left)
    return (left + right) / 2.0
167210
168- # def modelErr(dataSet):
169- # ws,X,Y = linearSolve(dataSet)
170- # yHat = X * ws
171- # return sum(power(Y - yHat,2))
def prune(tree, testData):
    """
    Post-prune a regression tree using a test set.

    Sub-trees are pruned first (bottom-up). When both children of a node are
    leaves, the node is replaced by the mean of the two leaves if that gives
    a smaller squared error on the test data than keeping the split.

    Parameters:
        tree - tree node (dict with 'spInd', 'spVal', 'left', 'right'),
               or an already collapsed leaf value
        testData - test set; its last column (testData[:, -1]) is compared
                   against the leaf values
    Returns:
        the pruned tree (same dict, updated in place), or a single leaf
        value when the node was collapsed or merged
    Website:
        http://www.cuijiahua.com/
    Modify:
        2017-12-14
    """
    # A bare leaf has no branches to prune; hand it back unchanged instead
    # of failing on the subscript below.
    if not isTree(tree):
        return tree
    # No test data reaches this node: collapse it to its mean.
    if np.shape(testData)[0] == 0:
        return getMean(tree)
    # Split once; the same partitions feed both the recursion and the error
    # comparison. NOTE(review): assumes binSplitDataSet does not modify
    # testData - confirm against its definition.
    lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    # Prune the sub-trees first.
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)
    # Merging is only considered when both children are leaves.
    if isTree(tree['left']) or isTree(tree['right']):
        return tree
    # Squared error on the test data when the split is kept.
    errorNoMerge = (np.sum(np.power(lSet[:, -1] - tree['left'], 2))
                    + np.sum(np.power(rSet[:, -1] - tree['right'], 2)))
    # Squared error when both leaves are replaced by their mean.
    treeMean = (tree['left'] + tree['right']) / 2.0
    errorMerge = np.sum(np.power(testData[:, -1] - treeMean, 2))
    # Merge only when it strictly lowers the test error.
    if errorMerge < errorNoMerge:
        return treeMean
    return tree
172248
if __name__ == '__main__':
    # Build the regression tree from the training file and show it unpruned.
    print('剪枝前:')
    tree = createTree(np.mat(loadDataSet('ex2.txt')))
    print(tree)
    # Post-prune the same tree against the test file and show the result.
    print('\n 剪枝后:')
    test_Mat = np.mat(loadDataSet('ex2test.txt'))
    print(prune(tree, test_Mat))
0 commit comments