• 鸢尾花数据集,特征为连续值数据的决策树的多分类


    1.导入工具

    import pandas as pd
    from sklearn import preprocessing
    from sklearn import tree
    from sklearn.datasets import load_iris

    2.导入鸢尾花数据集,探索数据集
    iris=load_iris()
    #iris是一个字典,包含了数据、标签、标签名、数据描述等信息。可以通过键来索引对应值。
    iris
    #查看iris字典里的所有键
    dir(iris)
    iris.data
    #150个数据,每个数据都有四个维度的特征,每个特征都是连续数值
    iris.data.shape
    #四个特征列名
    iris.feature_names
    #标签,0,1,2对应三种不同的鸢尾花
    iris.target
    #三种鸢尾花的名字
    iris.target_names
    鸢尾花数据集的描述说明信息
    print(iris.DESCR)

    3.构建决策树模型
    dir(iris)
    clf=tree.DecisionTreeClassifier(max_depth=4)
    clf=clf.fit(iris.data, iris.target)
    clf

    4.可视化决策树
    import pydotplus
    from IPython.display import Image,display
    dot_data=tree.export_graphviz(clf,
                                 out_file=None,
                                 feature_names=iris.feature_names,
                                 class_names=iris.target_names,
                                 filled=True,
                                 rounded=True
                                 )
    graph=pydotplus.graph_from_dot_data(dot_data)
    display(Image(graph.create_png()))


    5.对整个训练集做预测
    clf.predict(iris.data)

    6.对单个样本做预测
    #假设有一朵新的鸢尾花,四个特征分别为6.6cm,2.5cm,4.3cm,1,3cm。用训练好的决策树判断它属于哪一类鸢尾花。
    import numpy as np
    a1=np.array([6.6, 2.5, 4.3, 1.3])
    a1
    a1.shape
    a1.reshape(1,-1).shape
    clf.predict(a1.reshape(1,-1))
    #属于第二类鸢尾花。
    7.对多个样本做预测
    a1=iris.data[30]
    a2=iris.data[70]
    a3=iris.data[120]
    import numpy as np
    b=np.row_stack((a1,a2,a3))
    b
    clf.predict(b)
    import numpy as np
    import matplotlib.pyplot as plt
    %matplotlib.colors import ListedIormap
    from matplotlib.colors import ListedColormap 
    from sklearn import datasets
    from sklearn import tree
    iris=datasets.load_iris()
    x=iris.data[:,2:4]#取出花瓣的长和宽
    y=iris.target#取出标签
    #计算散点图的上下界
    x_min,x_max=x[:,0].min() -.5,  x[:,0].max()+.5
    y_min,y_max=x[:,1].min() -.5,  x[:,1].max()+.5
    #绘制边界
    camo=cmap_light=ListedColormap(['#AAAAFF','#AAFFAA','#FFAAAA'])
    h=.02
    xx,yy=np.meshgrid(np.arange(x_min,x_max,h),np.arange(y_min,y_max,h))
    clf=tree.DecisionTreeClassifier(max_depth=4)
    clf=clf.fit(x, y)
    Z=clf.predict(np.c_[xx.ravel(),yy.ravel()])
    Z=Z.reshape(xx.shape)
    plt.figure()
    plt.pcolormesh(xx,yy,Z,cmap=cmap_light)
    plt.scatter(x[:,0],x[:,1],c=y) 
    plt.xlim(xx.min(),xx.max())
    plt.ylim(yy.min(),yy.max())
    plt.show()
     

  • 相关阅读:
    PHP爬虫类的并发与多线程处理技巧
    ubuntu20.04.6wifi图标消失问题解决方案
    [ MSF使用实例 ] 利用MS12-020漏洞导致windows靶机蓝屏
    第三十八章 在 UNIX®、Linux 和 macOS 上使用 IRIS(三)
    【QML】使用 QtQuick2的ListView创建一个列表(一)
    前端周刊第三十期
    面试题之MyBatis缓存
    React Router 路由守卫
    深入浅出MySQL-03-【MySQL中的运算符】
    pom管理规范
  • 原文地址:https://blog.csdn.net/m0_57431551/article/details/127935759