• 计算机视觉——python在一张图中画多条ROC线


    在验证分类算法的好坏时,经常需要用到AUC曲线,而在做不同分类模型的对比实验时,需要将不同模型的AUC曲线绘制到一张图里。

    1. 小型分类模型对比,可以直接调用的

    用一样的数据集做示例,简单地直接分别得到每个分类模型预测的结果。

    from sklearn.datasets import load_breast_cancer
    from sklearn import metrics
    from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
    from sklearn.model_selection import train_test_split
    import pylab as plt
    import warnings; warnings.filterwarnings('ignore')
    
    dataset = load_breast_cancer()
    data = dataset.data
    target = dataset.target
    X_train, X_test, y_train, y_test = train_test_split(data,target,test_size=0.2)
    # 模型1调用sklearn中的RandomForestClassifier
    rf1 = RandomForestClassifier(n_estimators=5)
    rf1.fit(X_train, y_train)
    pred1 = rf1.predict_proba(X_test)[:,1]
    # 模型2调用sklearn中的ExtraTreesClassifier
    rf2 = ExtraTreesClassifier(n_estimators=5)
    rf2.fit(X_train, y_train)
    pred2 = rf2.predict_proba(X_test)[:,1]
    
    # 画图部分
    fpr1, tpr1, threshold1 = metrics.roc_curve(y_test, pred1)       #  
    roc_auc1 = metrics.auc(fpr1, tpr1)
    
    fpr2, tpr2, threshold2 = metrics.roc_curve(y_test, pred2)       #  
    roc_auc2 = metrics.auc(fpr2, tpr2)
    
    plt.figure(figsize=(6,6))
    plt.title('Validation ROC')
    plt.plot(fpr1, tpr1, 'b', label = 'RandomForestClassifier AUC = %0.3f' % roc_auc1)
    plt.plot(fpr2, tpr2, 'b', label = 'ExtraTreesClassifier AUC = %0.3f' % roc_auc2)
    plt.legend(loc = 'lower right')
    plt.plot([0, 1], [0, 1],'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.savefig("filename.png")
    plt.show()
    plt.close()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40

    在这里插入图片描述

    2. 大型的CNN模型,无法直接得到结果。

    2.1 先分别运行每个分类模型,将预测的结果存入csv文件中。

    from sklearn.datasets import load_breast_cancer
    from sklearn import metrics
    from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
    from sklearn.model_selection import train_test_split
    import pylab as plt
    import warnings; warnings.filterwarnings('ignore')
    import pandas as pd
    import numpy as np
    from sklearn.linear_model import LinearRegression   #线性回归
    import csv
    
    dataset = load_breast_cancer()
    data = dataset.data
    target = dataset.target
    X_train, X_test, y_train, y_test = train_test_split(data,target,test_size=0.2)
    rf1 = RandomForestClassifier(n_estimators=5)
    rf1.fit(X_train, y_train)
    pred1 = rf1.predict_proba(X_test)[:,1]
    
    dataframe = pd.DataFrame({'label':y_test,'pred':pred1})
    dataframe.to_csv("test1.csv",index=False,sep=',')
    
    rf2 = ExtraTreesClassifier(n_estimators=5)
    rf2.fit(X_train, y_train)
    pred2 = rf2.predict_proba(X_test)[:,1]
    
    dataframe = pd.DataFrame({'label':y_test,'pred':pred2})
    dataframe.to_csv("test2.csv",index=False,sep=',')
    
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    2.2 从csv文件读取每个模型的预测结果,绘制AUC曲线

    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt
    from sklearn.metrics import roc_curve, auc
    
    def Draw_ROC(file1,file2):
        '''这里注意读取csv的编码方式,
        如果csv里有中文,在windows系统上可以直接用encoding='ANSI',
        但是到了Mac或者Linux系统上会报错:`LookupError: unknown encoding: ansi`。
        解决方法:
        1. 可以改成encoding='gbk';
        2. 或者把csv文件里的列名改成英文,就不用选择encoding的方式了。
        '''
        data1=pd.read_csv(file1, encoding='ANSI')
        data1=pd.DataFrame(data1)
        data2=pd.read_csv(file2, encoding='ANSI')
        data2=pd.DataFrame(data2)
        print(list(data1['label']), list(data1['pred']))
        print(list(data2['label']), list(data2['pred']))
    
        fpr_CSNN,tpr_CSNN,thresholds=roc_curve(list(data1['label']),
                                               list(data1['pred']))
        roc_auc_CSSSNN=auc(fpr_CSNN,tpr_CSNN)
    
        fpr_NN,tpr_NN,thresholds=roc_curve(list(data2['label']),
                                           list(data2['pred']))
        roc_auc_DL=auc(fpr_NN,tpr_NN)
    
        font = {'family': 'Times New Roman',
                'size': 12,
                }
        '''这里很多电脑上也许默认是'DejaVu Sans'格式,但是在写论文时,
        往往需要'Times New Roman'格式,可以参考[这篇教程](https://blog.csdn.net/weixin_43543177/article/details/109723328)
        '''
        sns.set(font_scale=1.2)
        plt.rc('font',family='Times New Roman')
    
        plt.plot(fpr_NN,tpr_NN,'purple',label='NN_AUC = %0.2f'% roc_auc_DL)
        plt.plot(fpr_CSNN,tpr_CSNN,'blue',label='CSNN_AUC = %0.2f'% roc_auc_CSSSNN)
        plt.legend(loc='lower right',fontsize = 12)
        plt.plot([0,1],[0,1],'r--')
        plt.ylabel('True Positive Rate',fontsize = 14)
        plt.xlabel('Flase Positive Rate',fontsize = 14)
        plt.show()
    
    if __name__=="__main__":
        Draw_ROC('./test1.csv',
                 './test2.csv')
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49

    在这里插入图片描述

  • 相关阅读:
    [React] react-hooks如何使用
    java 企业工程管理系统软件源码 自主研发 工程行业适用
    ActiveMQ-架构设计
    ch2_2系统调用的实现
    leetcode:1154. 一年中的第几天(python3解法)
    个人主页汇总 | 私信没空看,建议b站
    Linux 下安装 node.js
    如何将抓取下来的unicode字符串转换为中文
    【无标题】
    宝塔一键安装wordpress
  • 原文地址:https://blog.csdn.net/everyxing1007/article/details/127879275