文章目录

利用matplotlib进行决策树可视化
代码来源于《机器学习实战》,具体讲解可参见这本书。
部分代码的解释可见这里,讲得比较清楚。

treePlot.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import matplotlib.pyplot as plt
from pylab import *
import treesID3 as decTree
# Set the sans-serif font family to SimHei for every figure created afterwards.
# NOTE(review): presumably so that Chinese node/edge labels render correctly;
# `mpl` is expected to come from the `from pylab import *` above -- confirm.
mpl.rcParams['font.sans-serif'] = ['SimHei']

# Box style used for internal (decision) nodes.
decNode = {"boxstyle": "sawtooth", "fc": "0.8"}
# Box style used for leaf nodes.
leafNode = {"boxstyle": "round4", "fc": "0.8"}
# Arrow style for the connector drawn between a node and its parent.
arrow_args = {"arrowstyle": "<-"}

def plotNode(nodeTxt, centerPt, fatherPt, nodeType):
    """Annotate the shared axes with one boxed node connected to its parent.

    Args:
        nodeTxt: Label passed to ``annotate`` as the annotation text.
        centerPt: ``(x, y)`` position of the text box, in axes-fraction
            coordinates.
        fatherPt: ``(x, y)`` position of the parent node (the annotated
            point), in axes-fraction coordinates.
        nodeType: Box-style mapping passed as ``bbox`` (e.g. ``decNode`` or
            ``leafNode``).
    """
    annotation_options = {
        "xy": fatherPt,
        "xycoords": 'axes fraction',
        "xytext": centerPt,
        "textcoords": 'axes fraction',
        "va": 'center',
        "ha": 'center',
        "bbox": nodeType,
        "arrowprops": arrow_args,
    }
    # The axes object is stored as an attribute on createPlot by createPlot().
    createPlot.ax1.annotate(nodeTxt, **annotation_options)

def getNumLeafs(Tree):
    """Return the number of leaf nodes in a nested-dict decision tree.

    The tree is expected in the form ``{root: {branch: subtree_or_leaf, ...}}``:
    the first key of ``Tree`` is the root label and its value maps each
    branch to either another tree of the same form or a leaf value.

    Args:
        Tree: Nested mapping describing the tree.

    Returns:
        The total count of non-mapping values reachable from the root.
    """
    root = list(Tree)[0]
    leaf_count = 0
    for child in Tree[root].values():
        # isinstance (rather than an exact type comparison) so that dict
        # subclasses such as OrderedDict are still recognised as subtrees.
        if isinstance(child, dict):
            leaf_count += getNumLeafs(child)
        else:
            leaf_count += 1
    return leaf_count

def DepthofTree(Tree):
    """Return the depth of a nested-dict decision tree.

    The tree is expected in the form ``{root: {branch: subtree_or_leaf, ...}}``.
    A root whose branches are all leaves has depth 1; each nested subtree
    adds one level.

    Args:
        Tree: Nested mapping describing the tree.

    Returns:
        The number of levels on the longest root-to-leaf path, or 0 when the
        root has no branches.
    """
    root = list(Tree)[0]
    maxdepth = 0
    for child in Tree[root].values():
        # isinstance (rather than an exact type comparison) so that dict
        # subclasses such as OrderedDict are still descended into.
        if isinstance(child, dict):
            depth = 1 + DepthofTree(child)
        else:
            depth = 1
        maxdepth = max(maxdepth, depth)
    return maxdepth

def plotMidText(nowPt, fatherPt, txt):
    """Write ``txt`` at the midpoint of the segment from ``nowPt`` to ``fatherPt``.

    Args:
        nowPt: ``(x, y)`` of the child node.
        fatherPt: ``(x, y)`` of the parent node.
        txt: Text to place on the shared axes (used here as the branch label).
    """
    half_dx = (fatherPt[0] - nowPt[0]) / 2.0
    half_dy = (fatherPt[1] - nowPt[1]) / 2.0
    # The axes object is stored as an attribute on createPlot by createPlot().
    createPlot.ax1.text(half_dx + nowPt[0], half_dy + nowPt[1], txt)

def plotTree(Tree, fatherPt, nodeTxt):
    """Recursively draw a nested-dict decision tree onto the shared axes.

    Layout state is kept in function attributes that ``createPlot`` sets
    before the first call:

    * ``plotTree.totalW`` -- total number of leaves (horizontal slots).
    * ``plotTree.totalD`` -- total tree depth (vertical levels).
    * ``plotTree.xoff``   -- x position of the most recently placed leaf.
    * ``plotTree.yoff``   -- y position of the level currently being drawn.

    Args:
        Tree: Tree of the form ``{root: {branch: subtree_or_leaf, ...}}``.
        fatherPt: ``(x, y)`` of the parent node the root is connected to.
        nodeTxt: Label written on the edge between the parent and this root.
    """
    numLeafs = getNumLeafs(Tree)
    root = list(Tree)[0]
    # Centre this node above the numLeafs leaf slots that follow xoff.
    nowPt = (plotTree.xoff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
             plotTree.yoff)
    plotMidText(nowPt, fatherPt, nodeTxt)
    plotNode(root, nowPt, fatherPt, decNode)

    # Step down one level while the children are drawn.
    plotTree.yoff = plotTree.yoff - 1.0 / plotTree.totalD
    for key, child in Tree[root].items():
        # isinstance so that dict subclasses are also treated as subtrees.
        if isinstance(child, dict):
            plotTree(child, nowPt, str(key))
        else:
            # Advance to the next leaf slot, then draw the leaf and its
            # branch label.
            plotTree.xoff = plotTree.xoff + 1.0 / plotTree.totalW
            leafPt = (plotTree.xoff, plotTree.yoff)
            plotNode(child, leafPt, nowPt, leafNode)
            plotMidText(leafPt, nowPt, str(key))
    # Restore the level so siblings of this subtree are drawn at the right y.
    plotTree.yoff = plotTree.yoff + 1.0 / plotTree.totalD

def createPlot(Tree):
    """Render a nested-dict decision tree in figure 1 and show it.

    Clears figure 1, creates a frameless, tick-less subplot stored as
    ``createPlot.ax1``, initialises the layout attributes on ``plotTree``
    and then draws the whole tree starting from the top centre.

    Args:
        Tree: Tree of the form ``{root: {branch: subtree_or_leaf, ...}}``.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])

    leaf_total = float(getNumLeafs(Tree))
    level_total = float(DepthofTree(Tree))
    plotTree.totalW = leaf_total
    plotTree.totalD = level_total
    # Split the unit width into 2 * leaf_total parts and start half a leaf
    # slot to the left of 0.
    plotTree.xoff = -0.5 / leaf_total
    plotTree.yoff = 1.0

    plotTree(Tree, (0.5, 1.0), '')
    plt.show()

if __name__ == '__main__':
    import sys

    # The data file may be given as the first command-line argument; without
    # one, fall back to the original hard-coded location.
    if len(sys.argv) > 1:
        filename = sys.argv[1]
    else:
        filename = "D:\\MLinAction\\Data\\西瓜2.0.txt"
    # Load the data set and feature names, build the tree, then print and
    # plot it.
    DataSet, featname = decTree.filetoDataSet(filename)
    Tree = decTree.createDecisionTree(DataSet, featname)
    print(Tree)
    createPlot(Tree)
机器学习 | Mac.Learning