@@ -6,11 +6,9 @@ class DesicionTree():
6
6
def __init__ (self ):
7
7
pass
8
8
9
- def _calcShannonEnt (self , dataSet ): ## 计算数据集的熵
10
- numEntries = len (dataSet )
9
+ def _calcShannonEnt (self , classList ): ## 计算数据集的熵
11
10
classCounts = {}
12
- for data in dataSet :
13
- currentLabel = data [- 1 ]
11
+ for currentLabel in classList :
14
12
if currentLabel not in classCounts :
15
13
classCounts [currentLabel ] = 1
16
14
else :
@@ -24,37 +22,39 @@ def _calcShannonEnt(self, dataSet): ## 计算数据集的熵
24
22
'''
25
23
shannonEnt = 0.0
26
24
for key in classCounts :
27
- prob = classCounts [key ]/ float (numEntries )
25
+ prob = classCounts [key ]/ float (len ( classList ) )
28
26
shannonEnt -= prob * math .log (prob , 2 ) # log base 2
29
27
return shannonEnt
30
28
31
- def _splitDataSet (self , dataSet , axis , value ):
32
- retDataSet = []
33
- for data in dataSet :
29
+ def _splitDataSet (self , dataArr , classList , axis , value ):
30
+ retFeatData = []
31
+ retLabelData = []
32
+ for data , label in zip (dataArr , classList ):
34
33
# print data[axis]
35
34
if data [axis ] == value :
36
- reduceddata = data [:axis ]
37
- reduceddata .extend (data [axis + 1 :])
38
- retDataSet .append (reduceddata )
39
- return retDataSet
35
+ reducedFeat = data [:axis ]
36
+ reducedFeat .extend (data [axis + 1 :])
37
+ retFeatData .append (reducedFeat )
38
+ retLabelData .append (label )
39
+ return retFeatData , retLabelData
40
40
41
def _chooseBestFeatureToSplit(self, dataArr, classList):
    """Return the index of the feature with the highest information gain.

    Information gain = H(D) - H(D|feature): the entropy of the label
    set minus the conditional entropy after splitting on the feature.
    Returns -1 when no feature yields a positive gain.
    """
    baseEntropy = self._calcShannonEnt(classList)
    bestInfoGain = 0
    bestFeature = -1
    numFeatures = len(dataArr[0])
    totalCount = float(len(classList))  # hoisted: invariant across both loops
    for i in range(numFeatures):  # iterate over every feature
        values = set(data[i] for data in dataArr)
        # Conditional entropy: sigma over values of p(value) * H(subset).
        newEntropy = 0.0
        for value in values:
            subDataArr, subClassList = self._splitDataSet(dataArr, classList, i, value)
            prob = len(subClassList) / totalCount
            newEntropy += prob * self._calcShannonEnt(subClassList)
        # Information gain = entropy - conditional entropy.
        # NOTE(review): the comparison/return below is hidden between diff
        # hunks in the reviewed view; reconstructed from the visible
        # bestInfoGain/bestFeature initialization -- confirm against the file.
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
@@ -66,33 +66,34 @@ def _chooseBestFeatureToSplit(self, dataSet):
66
66
67
67
def _majorityCnt (self , classList ):
68
68
classCount = {}
69
- for vote in classList :
70
- if vote not in classCount :
71
- classCount [vote ] = 1
69
+ for currentLabel in classList :
70
+ if currentLabel not in classCount :
71
+ classCount [currentLabel ] = 1
72
72
else :
73
- classCount [vote ] += 1
74
- # if vote not in classCount:
75
- # classCount[vote ] = 0
76
- # classCount[vote ] += 1
73
+ classCount [currentLabel ] += 1
74
+ # if currentLabel not in classCount:
75
+ # classCount[currentLabel ] = 0
76
+ # classCount[currentLabel ] += 1
77
77
sortedClassCount = sorted (classCount .items (), key = lambda xx :xx [1 ], reverse = True )
78
78
return sortedClassCount [0 ][0 ]
79
79
80
def fit(self, dataArr, classList, featLabels):
    """Recursively build an ID3 decision tree.

    dataArr    -- list of feature rows (one list per sample)
    classList  -- class label of each sample, aligned with dataArr
    featLabels -- human-readable name of each feature column

    Returns either a class label (leaf) or a nested dict of the form
    {feature_label: {feature_value: subtree}}.
    """
    # All labels identical: branch is pure, return that label.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Features exhausted but labels still mixed: majority vote.
    if len(dataArr[0]) == 0:
        return self._majorityCnt(classList)
    bestFeat = self._chooseBestFeatureToSplit(dataArr, classList)
    bestFeatLabel = featLabels[bestFeat]
    tree = {bestFeatLabel: {}}
    featLabels_copy = featLabels[:]  # copy so the caller's list is untouched
    featLabels_copy.remove(bestFeatLabel)
    featList = [data[bestFeat] for data in dataArr]
    values = set(featList)
    for value in values:
        subFeatLabels_copy = featLabels_copy[:]  # list copy, not a reference
        # Split once and unpack (the original called _splitDataSet twice
        # per value, doing the whole split work twice).
        subDataArr, subClassList = self._splitDataSet(dataArr, classList, bestFeat, value)
        tree[bestFeatLabel][value] = self.fit(subDataArr, subClassList, subFeatLabels_copy)
    return tree
97
98
98
99
def predict (self , tree , featLabels , testVec ):
def loadDataSet():
    """Build the toy data set and its feature names.

    Returns (featData, labelData, featLabels): feature rows, the class
    label of each row, and the name of each feature column.
    """
    # NOTE(review): the first two rows are hidden in the reviewed diff and
    # were reconstructed from the classic reference data set -- confirm
    # against the full file.
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    # Split each row into features (all but the last column) and label.
    featData = [data[:-1] for data in dataSet]
    labelData = [data[-1] for data in dataSet]
    featLabels = ['no surfacing', 'flippers']  # feature names
    return featData, labelData, featLabels
118
124
119
125
if __name__ == '__main__' :
    # Load the toy data set: feature rows, labels, and feature names.
    myFeatData , myLabelData , myFeatLabels = loadDataSet ()
    print myFeatData , myLabelData , myFeatLabels
    # Build the decision tree from the training data.
    dt = DesicionTree ()
    myTree = dt .fit (myFeatData , myLabelData , myFeatLabels )
    print myTree
    # Classify one sample: [no surfacing=1, flippers=1].
    results = dt .predict (myTree , myFeatLabels , [1 , 1 ])
    print results
0 commit comments