Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit b70f532

Browse files
决策树代码
1 parent 57a4278 commit b70f532

File tree

1 file changed

+153
-0
lines changed

1 file changed

+153
-0
lines changed

‎day-123/dt.py‎

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
from math import log
2+
import operator
3+
4+
5+
def createDataSet():
    '''
    Build the small toy training set used by the demo.

    Each row is three binary feature values followed by a class
    label 'y'/'n' in the last column.

    :return: (dataSet, labels) — the sample rows and the names of
             the three feature columns
    '''
    rows = [
        [1, 1, 0, 'y'],
        [1, 1, 0, 'y'],
        [1, 0, 0, 'n'],
        [0, 1, 0, 'n'],
        [0, 0, 1, 'n'],
        [1, 0, 1, 'n'],
        [1, 1, 1, 'n'],
    ]
    featureNames = ['Salary', 'Time', 'Bank flow']
    return rows, featureNames
19+
20+
21+
def calcEntropy(dataSet):
    '''
    Compute the Shannon entropy of a data set.

    :param dataSet: list of samples; the last element of each
                    sample is taken as its class label
    :return: entropy in bits (float)
    '''
    total = len(dataSet)
    # Frequency of each class label.
    counts = {}
    for sample in dataSet:
        label = sample[-1]
        counts[label] = counts.get(label, 0) + 1
    # H = -sum(p * log2(p)) over the label distribution.
    result = 0.0
    for cnt in counts.values():
        p = cnt / float(total)
        result -= p * log(p, 2)
    return result
40+
41+
42+
def splitDataSet(dataSet, axis, value):
    '''
    Select the samples whose feature at ``axis`` equals ``value``
    and strip that feature column from each kept sample.

    :param dataSet: list of samples
    :param axis: index of the feature column to filter on
    :param value: value a sample must have at ``axis`` to be kept
    :return: new list of reduced samples (input is not modified)
    '''
    return [row[:axis] + row[axis + 1:]
            for row in dataSet
            if row[axis] == value]
57+
58+
59+
60+
61+
def chooseBestFeatureToSplit(dataSet):
    '''
    Pick the feature with the highest information gain (ID3 rule).

    Prints the gain of every feature as a side effect.

    :param dataSet: list of samples; the last column is the class
                    label, all other columns are features
    :return: index of the best feature, or -1 when no feature
             achieves a strictly positive gain
    '''
    featureCount = len(dataSet[0]) - 1
    # Entropy of the whole set before any split.
    baseEntropy = calcEntropy(dataSet)
    bestGain = 0.0
    bestIndex = -1

    for idx in range(featureCount):
        # Distinct values taken by feature ``idx``.
        values = {sample[idx] for sample in dataSet}
        # Conditional entropy: entropy of each value-subset,
        # weighted by the subset's share of the data.
        condEntropy = 0.0
        for val in values:
            subset = splitDataSet(dataSet, idx, val)
            weight = len(subset) / float(len(dataSet))
            condEntropy += weight * calcEntropy(subset)
        gain = baseEntropy - condEntropy
        # Report every feature's gain (message kept verbatim).
        print("第%d个特征属性的信息增益为%.3f" % (idx, gain))
        # Strict '>' keeps the first feature on ties.
        if gain > bestGain:
            bestGain = gain
            bestIndex = idx
    return bestIndex
101+
102+
103+
def majorityCnt(classList):
    '''
    Return the most frequent class label.

    :param classList: non-empty list of class labels
    :return: the label that occurs most often (on a tie, the one
             that sorts first among the equally-frequent labels
             after the stable descending sort)
    '''
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # BUG FIX: dict.iteritems() is Python 2 only and raises
    # AttributeError on Python 3; items() works on both.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1),
                              reverse=True)
    return sortedClassCount[0][0]
115+
116+
def createTree(dataSet, labels):
    '''
    Recursively build an ID3 decision tree as nested dicts.

    NOTE: ``labels`` is mutated — the chosen feature's name is
    deleted from it at each level of the recursion.

    :param dataSet: samples; last column is the class label
    :param labels: feature names aligned with the data columns
    :return: a nested dict {feature: {value: subtree, ...}}, or a
             bare class label at a leaf
    '''
    classList = [row[-1] for row in dataSet]
    # Leaf: every remaining sample has the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Leaf: no features left to split on — majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    bestFeature = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeature]
    tree = {bestFeatLabel: {}}
    del labels[bestFeature]
    # Grow one branch per observed value of the chosen feature;
    # each branch gets its own copy of the remaining labels.
    observedValues = set(row[bestFeature] for row in dataSet)
    for value in observedValues:
        branchData = splitDataSet(dataSet, bestFeature, value)
        tree[bestFeatLabel][value] = createTree(branchData, labels[:])
    return tree
144+
145+
146+
147+
# Demo: build the decision tree from the toy data set and print it.
trainingData, featureLabels = createDataSet()

print(createTree(trainingData, featureLabels))

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /