Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit f9fd2cd

Browse files
test
1 parent 0b2b686 commit f9fd2cd

File tree

1 file changed

+43
-37
lines changed

1 file changed

+43
-37
lines changed

‎DesicionTree/DesicionTreeTest.py

Lines changed: 43 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@ class DesicionTree():
66
def __init__(self):
77
pass
88

9-
def _calcShannonEnt(self, dataSet): ## 计算数据集的熵
10-
numEntries = len(dataSet)
9+
def _calcShannonEnt(self, classList): ## 计算数据集的熵
1110
classCounts = {}
12-
for data in dataSet:
13-
currentLabel = data[-1]
11+
for currentLabel in classList:
1412
if currentLabel not in classCounts:
1513
classCounts[currentLabel] = 1
1614
else:
@@ -24,37 +22,39 @@ def _calcShannonEnt(self, dataSet): ## 计算数据集的熵
2422
'''
2523
shannonEnt = 0.0
2624
for key in classCounts:
27-
prob = classCounts[key]/float(numEntries)
25+
prob = classCounts[key]/float(len(classList))
2826
shannonEnt -= prob*math.log(prob, 2) # log base 2
2927
return shannonEnt
3028

31-
def _splitDataSet(self, dataSet, axis, value):
32-
retDataSet = []
33-
for data in dataSet:
29+
def _splitDataSet(self, dataArr, classList, axis, value):
30+
retFeatData = []
31+
retLabelData = []
32+
for data, label in zip(dataArr, classList):
3433
# print data[axis]
3534
if data[axis] == value:
36-
reduceddata = data[:axis]
37-
reduceddata.extend(data[axis+1:])
38-
retDataSet.append(reduceddata)
39-
return retDataSet
35+
reducedFeat = data[:axis]
36+
reducedFeat.extend(data[axis+1:])
37+
retFeatData.append(reducedFeat)
38+
retLabelData.append(label)
39+
return retFeatData, retLabelData
4040

41-
def _chooseBestFeatureToSplit(self, dataSet):
42-
numFeatures = len(dataSet[0])-1 # 最后一列是类标签
43-
baseEntropy = self._calcShannonEnt(dataSet)
41+
def _chooseBestFeatureToSplit(self, dataArr, classList):
42+
baseEntropy = self._calcShannonEnt(classList)
4443
bestInfoGain = 0
4544
bestFeature = -1
45+
numFeatures = len(dataArr[0])
4646
for i in range(numFeatures): # 依次迭代所有的特征
47-
featList = [data[i] for data in dataSet]
47+
featList = [data[i] for data in dataArr]
4848
values = set(featList)
4949
'''
5050
条件熵:sigma(pj*子数据集的熵)
5151
'''
5252
## 计算每个特征对数据集的条件熵
5353
newEntropy = 0.0
5454
for value in values:
55-
subDataSet= self._splitDataSet(dataSet, i, value)
56-
prob = len(subDataSet)/float(len(dataSet))
57-
newEntropy += prob*self._calcShannonEnt(subDataSet)
55+
subDataArr, subClassList= self._splitDataSet(dataArr, classList, i, value)
56+
prob = len(subClassList)/float(len(classList))
57+
newEntropy += prob*self._calcShannonEnt(subClassList)
5858
'''
5959
信息增益 = 熵-条件熵
6060
'''
@@ -66,33 +66,34 @@ def _chooseBestFeatureToSplit(self, dataSet):
6666

6767
def _majorityCnt(self, classList):
6868
classCount = {}
69-
for vote in classList:
70-
if vote not in classCount:
71-
classCount[vote] = 1
69+
for currentLabel in classList:
70+
if currentLabel not in classCount:
71+
classCount[currentLabel] = 1
7272
else:
73-
classCount[vote] += 1
74-
# if vote not in classCount:
75-
# classCount[vote] = 0
76-
# classCount[vote] += 1
73+
classCount[currentLabel] += 1
74+
# if currentLabel not in classCount:
75+
# classCount[currentLabel] = 0
76+
# classCount[currentLabel] += 1
7777
sortedClassCount = sorted(classCount.items(), key=lambda xx:xx[1], reverse=True)
7878
return sortedClassCount[0][0]
7979

80-
def fit(self, dataSet, featLabels):
81-
classList = [data[-1] for data in dataSet]
80+
def fit(self, dataArr, classList, featLabels):
8281
if classList.count(classList[0]) == len(classList):
8382
return classList[0] # 所有的类标签都相同,则返回类标签
84-
if len(dataSet[0]) == 1: # 所有的类标签不完全相同,但用完所有特征,则返回次数最多的类标签
83+
if len(dataArr[0]) == 0: # 所有的类标签不完全相同,但用完所有特征,则返回次数最多的类标签
8584
return self._majorityCnt(classList)
86-
bestFeat = self._chooseBestFeatureToSplit(dataSet)
85+
bestFeat = self._chooseBestFeatureToSplit(dataArr, classList)
8786
bestFeatLabel = featLabels[bestFeat]
8887
tree = {bestFeatLabel:{}}
8988
featLabels_copy = featLabels[:] # 这样不会改变输入的featLabels
9089
featLabels_copy.remove(bestFeatLabel)
91-
featList = [data[bestFeat] for data in dataSet]
90+
featList = [data[bestFeat] for data in dataArr]
9291
values = set(featList)
9392
for value in values:
94-
subfeatLabels_copy = featLabels_copy[:] # 列表复制,非列表引用
95-
tree[bestFeatLabel][value] = self.fit(self._splitDataSet(dataSet, bestFeat, value), subfeatLabels_copy)
93+
subFeatLabels_copy = featLabels_copy[:] # 列表复制,非列表引用
94+
subDataArr = self._splitDataSet(dataArr, classList, bestFeat, value)[0]
95+
subClassList = self._splitDataSet(dataArr, classList, bestFeat, value)[1]
96+
tree[bestFeatLabel][value] = self.fit(subDataArr, subClassList, subFeatLabels_copy)
9697
return tree
9798

9899
def predict(self, tree, featLabels, testVec):
@@ -113,14 +114,19 @@ def loadDataSet():
113114
[1, 0, 'no'],
114115
[0, 1, 'no'],
115116
[0, 1, 'no']]
117+
featData = []
118+
labelData = []
119+
for data in dataSet:
120+
featData.append(data[:-1])
121+
labelData.append(data[-1])
116122
featLabels = ['no surfacing', 'flippers'] # 特征标签
117-
return dataSet, featLabels
123+
return featData, labelData, featLabels
118124

119125
if __name__ == '__main__':
120-
myDataSet, myFeatLabels = loadDataSet()
121-
print myDataSet, myFeatLabels
126+
myFeatData, myLabelData, myFeatLabels = loadDataSet()
127+
print myFeatData, myLabelData, myFeatLabels
122128
dt = DesicionTree()
123-
myTree = dt.fit(myDataSet, myFeatLabels)
129+
myTree = dt.fit(myFeatData, myLabelData, myFeatLabels)
124130
print myTree
125131
results = dt.predict(myTree, myFeatLabels, [1, 1])
126132
print results

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /