#coding=utf-8
'''
Build a decision tree with the ID3 algorithm: splits are chosen by
information gain, computed from Shannon entropy.
'''
from math import log
import operator

def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']  # the feature names used for splitting
    return dataSet, labels

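# A quick look at the toy data: two binary feature columns plus the class
# label in the last column.
#   >>> dataSet, labels = createDataSet()
#   >>> dataSet[0], labels
#   ([1, 1, 'yes'], ['no surfacing', 'flippers'])
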
# Compute the Shannon entropy of a data set
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]  # the class label is the last column
        # build a dict of label counts
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1  # one more row with this label
    # accumulate the entropy: -sum(p * log2(p)) over all labels
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

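# Sanity check on the toy data above: the class counts are {'yes': 2,
# 'no': 3}, so the entropy is -(2/5)*log2(2/5) - (3/5)*log2(3/5):
#   >>> calcShannonEnt(createDataSet()[0])
#   0.9709505944546686
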
# Split the data set. The three parameters are the data set to split,
# the index of the feature to split on, and the feature value to keep.
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # extract the matching rows, with the split feature removed
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet  # a list of the reduced feature vectors

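# Example on the toy data: keep the rows where feature 0 equals 1 and
# drop that column.
#   >>> splitDataSet(createDataSet()[0], 0, 1)
#   [[1, 'yes'], [1, 'yes'], [0, 'no']]
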
# Choose the best feature to split the data set on
def chooseBestFeatureToSplit(dataSet):
    numFeature = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeature):
        featureList = [example[i] for example in dataSet]  # all values taken by feature i
        uniqueVals = set(featureList)  # build a set to get the distinct values
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)  # split on feature i == value
            prob = len(subDataSet) / float(len(dataSet))  # fraction of rows in this subset
            newEntropy += prob * calcShannonEnt(subDataSet)  # weighted entropy of the split
        infoGain = baseEntropy - newEntropy
        # keep the split with the largest information gain: the larger the
        # gain, the more the feature tells us about the class
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

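# On the toy data, splitting on feature 0 ('no surfacing') yields an
# information gain of about 0.42 versus about 0.17 for feature 1, so
# feature 0 wins:
#   >>> chooseBestFeatureToSplit(createDataSet()[0])
#   0
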
# Majority vote over class labels (used when no features are left)
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    # sort by count in descending order
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]  # the most frequent class

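# Example: 'no' wins the majority vote.
#   >>> majorityCnt(['yes', 'no', 'no'])
#   'no'
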
# Build the decision tree recursively
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):  # stop when every class is identical
        return classList[0]
    if len(dataSet[0]) == 1:  # all features used up: return the majority class
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # pick the best feature to split on
    bestFeatLabel = labels[bestFeat]  # its human-readable name
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]  # this feature is consumed; drop its name at this level
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy so the recursion cannot mutate this level's labels
        # recurse into the subset that takes this feature value
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

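# On the toy data the recursion produces a nested dict keyed by feature
# names and feature values (the printed key order may vary):
#   >>> dataSet, labels = createDataSet()
#   >>> createTree(dataSet, labels)
#   {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
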
if __name__ == "__main__":
    dataSet, labels = createDataSet()
    print(createTree(dataSet, labels))