Let's take a quick look at an example.
First, import the common modules and configure the plot settings:
- # Common imports
- import numpy as np
- import os
-
- # to make this notebook's output stable across runs
- np.random.seed(42)
-
- # To plot pretty figures
- %matplotlib inline
- import matplotlib as mpl
- import matplotlib.pyplot as plt
- mpl.rc('axes', labelsize=14)
- mpl.rc('xtick', labelsize=12)
- mpl.rc('ytick', labelsize=12)
- from sklearn.datasets import load_iris
- from sklearn.tree import DecisionTreeClassifier,export_graphviz
-
- iris = load_iris()
- X = iris.data[:, 2:] # petal length and width
- y = iris.target
-
- tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
- tree_clf.fit(X, y)
-
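- # Visualize the decision tree
- # Site that renders the structure: http://webgraphviz.com/
- # Or http://dreampuf.github.io/GraphvizOnline/ : paste the contents of the dot file into the site and wait a moment for the graph to render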
- export_graphviz(tree_clf,out_file="iris1_tree.dot")
-
-
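Because the call above passes no feature or class names, the exported file labels the splits as X[0] and X[1]. The sketch below is one possible variation, not part of the original notebook: it passes `feature_names`, `class_names`, `rounded`, and `filled` to `export_graphviz`, and uses `out_file=None` so the dot source comes back as a string instead of being written to disk. The variable name `dot_source` is only illustrative.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()
X = iris.data[:, 2:]  # petal length and width, as in the example above
y = iris.target

tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X, y)

# out_file=None makes export_graphviz return the dot source as a string
dot_source = export_graphviz(
    tree_clf,
    out_file=None,
    feature_names=list(iris.feature_names[2:]),  # petal length / petal width
    class_names=list(iris.target_names),         # setosa, versicolor, virginica
    rounded=True,
    filled=True,
)
print(dot_source)
```

The printed string can be pasted into the same online viewers; the nodes should then show readable feature and class names instead of bare indices.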
Open the iris1_tree.dot file from the default path (the current working directory):
- digraph Tree {
- node [shape=box, fontname="helvetica"] ;
- edge [fontname="helvetica"] ;
- 0 [label="X[0] <= 2.45\ngini = 0.667\nsamples = 150\nvalue = [50, 50, 50]"] ;
- 1 [label="gini = 0.0\nsamples = 50\nvalue = [50, 0, 0]"] ;
- 0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
- 2 [label="X[1] <= 1.75\ngini = 0.5\nsamples = 100\nvalue = [0, 50, 50]"] ;
- 0 -> 2 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
- 3 [label="gini = 0.168\nsamples = 54\nvalue = [0, 49, 5]"] ;
- 2 -> 3 ;
- 4 [label="gini = 0.043\nsamples = 46\nvalue = [0, 1, 45]"] ;
- 2 -> 4 ;
- }
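Each node in the dot file reports a `gini` value next to its `value` list of per-class sample counts. As a quick sanity check, the Gini impurity can be recomputed from those counts with the standard formula, 1 minus the sum of squared class proportions. The helper name `gini_impurity` and the `node_values` dictionary are illustrative; the counts are copied from the file above.

```python
def gini_impurity(class_counts):
    """Gini impurity: 1 - sum(p_k^2), with p_k the share of class k in the node."""
    total = sum(class_counts)
    if total == 0:
        return 0.0
    return 1.0 - sum((count / total) ** 2 for count in class_counts)

# Per-class sample counts copied from the "value = [...]" labels in the dot file
node_values = {
    0: [50, 50, 50],
    1: [50, 0, 0],
    2: [0, 50, 50],
    3: [0, 49, 5],
    4: [0, 1, 45],
}

for node_id, counts in node_values.items():
    print(f"node {node_id}: gini = {gini_impurity(counts):.3f}")
# node 0: gini = 0.667
# node 1: gini = 0.000
# node 2: gini = 0.500
# node 3: gini = 0.168
# node 4: gini = 0.043
```

The recomputed values match the `gini` labels in the exported file.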
The detailed visualization steps are covered in this earlier post:
Machine Learning (18): Classification Algorithms (Supplement), WHJ226's blog on CSDN
In short: open the Graphviz Online site, then copy the contents of the dot file and paste them into the code area on the left.
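The result looks like this: (A PyCharm plugin can also visualize the decision tree, but since the method above has worked without problems so far, I have not explored it.)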
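If a quick look at the tree is all that is needed, a text dump avoids the website round trip entirely. The sketch below is an alternative that the original post does not use: it calls `sklearn.tree.export_text` on the same fitted tree. The variable name `text_report` is only illustrative.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
X = iris.data[:, 2:]  # petal length and width, as in the example above
y = iris.target

tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X, y)

# Render the fitted tree as indented plain-text rules
text_report = export_text(tree_clf, feature_names=list(iris.feature_names[2:]))
print(text_report)
```

The output lists each split as an indented rule and each leaf as a predicted class, which is often enough to confirm the thresholds before producing the graphical version.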