1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
| import matplotlib.pyplot as plt from pylab import * import treesID3 as decTree mpl.rcParams['font.sans-serif'] = ['SimHei']
# Box style for decision (internal) nodes: sawtooth border, fill colour '0.8'
# (a matplotlib grey-level string).
decNode = {"boxstyle": "sawtooth", "fc": "0.8"}
# Box style for leaf nodes: rounded ("round4") border, same grey fill.
leafNode = {"boxstyle": "round4", "fc": "0.8"}
# Arrow properties passed as `arrowprops` for the parent-to-child edges.
arrow_args = {"arrowstyle": "<-"}
def plotNode(nodeTxt, centerPt, fatherPt, nodeType):
    """Draw one boxed node label connected to its parent by an arrow.

    Both points are given in 'axes fraction' coordinates of the axes stored on
    ``createPlot.ax1`` (so ``createPlot`` must have set that attribute first).

    Args:
        nodeTxt: Text shown inside the node box.
        centerPt: (x, y) where the node text is centered.
        fatherPt: (x, y) of the parent node; the arrow connects it to centerPt.
        nodeType: bbox style dict (e.g. decNode or leafNode).
    """
    frame = 'axes fraction'
    createPlot.ax1.annotate(
        nodeTxt,
        xy=fatherPt,
        xytext=centerPt,
        xycoords=frame,
        textcoords=frame,
        ha='center',
        va='center',
        bbox=nodeType,
        arrowprops=arrow_args,
    )
def getNumLeafs(Tree):
    """Count the leaf labels in a nested-dict decision tree.

    The first key of ``Tree`` is taken as the root feature; its value maps each
    branch value to either a subtree (any dict, including dict subclasses) or a
    leaf label (anything else).

    Args:
        Tree: Nested dict of the form ``{feature: {branch_value: child}}``.

    Returns:
        int: Number of non-dict children reachable from the root.

    Raises:
        IndexError: If ``Tree`` (or any dict subtree) has no keys.
    """
    root = list(Tree)[0]
    branches = Tree[root]
    # isinstance (not type(...) == dict) so dict subclasses such as
    # OrderedDict/defaultdict subtrees are recursed into, not counted as leaves.
    return sum(
        getNumLeafs(child) if isinstance(child, dict) else 1
        for child in branches.values()
    )
def DepthofTree(Tree):
    """Return the number of decision levels in a nested-dict decision tree.

    The first key of ``Tree`` is taken as the root feature; each child that is a
    dict (including dict subclasses) is another decision level, anything else
    is a leaf.

    Args:
        Tree: Nested dict of the form ``{feature: {branch_value: child}}``.

    Returns:
        int: Longest chain of decision nodes from the root down to a leaf
        (1 when every child of the root is a leaf; 0 when the root has no
        branches).

    Raises:
        IndexError: If ``Tree`` (or any dict subtree) has no keys.
    """
    root = list(Tree)[0]
    branches = Tree[root]
    # isinstance keeps the check consistent for dict-subclass subtrees.
    return max(
        (1 + DepthofTree(child) if isinstance(child, dict) else 1
         for child in branches.values()),
        default=0,
    )
def plotMidText(nowPt, fatherPt, txt):
    """Write ``txt`` at the midpoint of the segment from ``nowPt`` to ``fatherPt``.

    Used to label the edge between a node and its parent with the branch value.
    Draws on ``createPlot.ax1``.
    """
    half_dx = (fatherPt[0] - nowPt[0]) / 2.0
    half_dy = (fatherPt[1] - nowPt[1]) / 2.0
    createPlot.ax1.text(nowPt[0] + half_dx, nowPt[1] + half_dy, txt)
def plotTree(Tree, fatherPt, nodeTxt):
    """Recursively draw ``Tree`` beneath the point ``fatherPt``.

    Layout state is kept in function attributes initialised by createPlot:
    ``plotTree.totalW`` (leaf count of the whole tree), ``plotTree.totalD``
    (depth of the whole tree), ``plotTree.xoff`` (x of the most recently placed
    leaf) and ``plotTree.yoff`` (y of the level currently being drawn).

    Args:
        Tree: Nested dict ``{feature: {branch_value: subtree_or_label}}``.
        fatherPt: (x, y) of the parent node in axes-fraction coordinates.
        nodeTxt: Branch label written on the edge from the parent.
    """
    numLeafs = getNumLeafs(Tree)
    root = list(Tree.keys())[0]
    # Center this decision node horizontally over the slots its leaves occupy.
    nowPt = (plotTree.xoff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
             plotTree.yoff)
    plotMidText(nowPt, fatherPt, nodeTxt)
    plotNode(root, nowPt, fatherPt, decNode)
    firstGen = Tree[root]
    # Move down one level for this node's children.
    plotTree.yoff = plotTree.yoff - 1.0 / plotTree.totalD
    for key, child in firstGen.items():
        if isinstance(child, dict):
            plotTree(child, nowPt, str(key))
        else:
            # Leaves are laid out left to right, one slot (1/totalW) apart.
            plotTree.xoff = plotTree.xoff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xoff, plotTree.yoff), nowPt, leafNode)
            plotMidText((plotTree.xoff, plotTree.yoff), nowPt, str(key))
    # Restore the level so sibling subtrees start at the same height.
    plotTree.yoff = plotTree.yoff + 1.0 / plotTree.totalD
def createPlot(Tree):
    """Render a nested-dict decision tree in matplotlib figure 1 and show it.

    Clears figure 1, creates a frameless axes without ticks (stored as
    ``createPlot.ax1`` for the drawing helpers), initialises the layout state
    on ``plotTree`` and draws the tree with its root at (0.5, 1.0).
    """
    plt.figure(1, facecolor='white').clf()
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    leaf_count = float(getNumLeafs(Tree))
    plotTree.totalW = leaf_count
    plotTree.totalD = float(DepthofTree(Tree))
    # Start half a slot left of 0 so the first leaf lands at 0.5 / totalW.
    plotTree.xoff = -1.0 / 2.0 / leaf_count
    plotTree.yoff = 1.0
    plotTree(Tree, (0.5, 1.0), '')
    plt.show()
if __name__ == '__main__':
    # The dataset path may be given as the first command-line argument;
    # without one, fall back to the original hard-coded location.
    import sys

    filename = sys.argv[1] if len(sys.argv) > 1 else "D:\\MLinAction\\Data\\西瓜2.0.txt"
    DataSet, featname = decTree.filetoDataSet(filename)
    Tree = decTree.createDecisionTree(DataSet, featname)
    print(Tree)
    createPlot(Tree)
|