• 机器学习 —— 计算评估指标


    计算评估指标

    • 假设有100个数据样本,其中有正样本70个,负样本30个
    • 现在模型查出有50个正样本,其中真正的正样本是30个
    • 求:精确率precision,召回率recall, F1值,准确率Accuracy

    TP = 30
    FP = 20
    TN = 10
    FN = 40

    # 精确率(查准率)
    precision = TP / (TP + FP) = 30 / 50 = 0.6
    # 召回率(查全率)
    recall = TP / (TP + FN) = 30 / 70 = 3/7
    # F1值
    f1 = (2 * precision * recall) / (precision + recall) = 0.5
    # 准确率
    accuracy = (TN + TP) / (TN + TP + FN + FP) = 0.4

    画ROC曲线 和 计算auc值

    1. import numpy as np
    2. import pandas as pd
    3. import matplotlib.pyplot as plt
    4. from sklearn.datasets import load_iris
    5. data,target = load_iris(return_X_y=True)
    6. # 二分类
    7. target2 = target[0:100].copy()
    8. data2 = data[:100].copy()

    使用LR模型

    • from sklearn.linear_model import LogisticRegression
    • from sklearn.model_selection import train_test_split
    1. from sklearn.linear_model import LogisticRegression
    2. from sklearn.model_selection import train_test_split
    3. x_train,x_test,y_train,y_test = train_test_split(data2,target2,test_size=0.2)
    4. lr = LogisticRegression()
    5. lr.fit(x_train,y_train)
    6. # 预测
    7. y_pred = lr.predict(x_test)
    8. y_pred
    9. # array([0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1])
    10. # ROC
    11. # metrics:评估
    12. from sklearn.metrics import roc_curve,auc

    ROC 曲线

    1. # y_true:真实结果
    2. # y_score:预测分数(此处传入的是预测标签 y_pred,因此 ROC 曲线只有少数几个点;更规范的做法是传入 predict_proba 得到的概率,后文即采用该方式)
    3. fpr,tpr,_ = roc_curve(y_test,y_pred) # 返回值:fpr,tpr,thresholds
    4. # fpr:伪阳率
    5. # tpr:真阳率
    6. display(fpr,tpr)
    7. '''
    8. array([0., 0., 1.])
    9. array([0., 1., 1.])
    10. '''
    11. plt.plot(fpr,tpr)

    auc

    1. auc(fpr,tpr)
    2. # 1.0

    使用交叉验证来计算auc值,平均auc值

    • from sklearn.model_selection import KFold, StratifiedKFold
    1. from sklearn.model_selection import KFold, StratifiedKFold
    2. skf = StratifiedKFold()
    3. data2.shape
    4. # (100, 4)
    5. list(skf.split(data2,target2))
    6. '''
    7. [(array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
    8. 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43,
    9. 44, 45, 46, 47, 48, 49, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70,
    10. 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
    11. 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]),
    12. array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 50, 51, 52, 53, 54, 55, 56,
    13. 57, 58, 59])),
    14. (array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 20, 21, 22, 23, 24, 25, 26,
    15. 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43,
    16. 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 70,
    17. 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
    18. 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]),
    19. array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 60, 61, 62, 63, 64, 65, 66,
    20. 67, 68, 69])),
    21. (array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    22. 17, 18, 19, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43,
    23. 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
    24. 61, 62, 63, 64, 65, 66, 67, 68, 69, 80, 81, 82, 83, 84, 85, 86, 87,
    25. 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]),
    26. array([20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 70, 71, 72, 73, 74, 75, 76,
    27. 77, 78, 79])),
    28. (array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    29. 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 40, 41, 42, 43,
    30. 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
    31. 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
    32. 78, 79, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]),
    33. array([30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 80, 81, 82, 83, 84, 85, 86,
    34. 87, 88, 89])),
    35. (array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    36. 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
    37. 34, 35, 36, 37, 38, 39, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
    38. 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
    39. 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89]),
    40. array([40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 90, 91, 92, 93, 94, 95, 96,
    41. 97, 98, 99]))]
    42. '''
    43. for train,test in skf.split(data2,target2):
    44. x_train = data2[train]
    45. y_train = target2[train]
    46. x_test = data2[test]
    47. y_test = target2[test]
    48. # LR
    49. lr = LogisticRegression()
    50. lr.fit(x_train,y_train)
    51. y_pred = lr.predict(x_test)
    52. # roc
    53. fpr,tpr,_ = roc_curve(y_test,y_pred)
    54. plt.plot(fpr,tpr)
    55. print(auc(fpr,tpr))
    56. '''
    57. 1.0
    58. 1.0
    59. 1.0
    60. 1.0
    61. 1.0
    62. '''

    添加噪声

    • 给data2添加500列随机值
    1. data2.shape
    2. # (100, 4)
    3. data3 = np.random.randn(100,500)
    4. data3.shape
    5. # (100, 500)
    6. # 左右拼接:水平拼接
    7. data4 = np.hstack((data2,data3))
    8. data4.shape
    9. # (100, 504)
    10. skf = StratifiedKFold()
    11. auc_list = []
    12. for train,test in skf.split(data4,target2):
    13. x_train = data4[train]
    14. y_train = target2[train]
    15. x_test = data4[test]
    16. y_test = target2[test]
    17. # LR
    18. lr = LogisticRegression()
    19. lr.fit(x_train,y_train)
    20. # 预测
    21. # y_pred = lr.predict(x_test)
    22. # 预测概率
    23. y_proba = lr.predict_proba(x_test)
    24. print('y_proba:',y_proba)
    25. # roc
    26. fpr,tpr,_ = roc_curve(y_test,y_proba[:,1])
    27. # 画图
    28. plt.plot(fpr,tpr)
    29. print('fpr:',fpr)
    30. print('tpr:',tpr)
    31. print('auc:',auc(fpr,tpr))
    32. print('*'*100)
    33. auc_list.append(auc(fpr,tpr))
    34. # 平均 auc
    35. np.array(auc_list).mean()
    36. '''
    37. y_proba: [[0.3267921 0.6732079 ]
    38. [0.96683557 0.03316443]
    39. [0.77520064 0.22479936]
    40. [0.65359444 0.34640556]
    41. [0.28117064 0.71882936]
    42. [0.51257663 0.48742337]
    43. [0.89757814 0.10242186]
    44. [0.70565166 0.29434834]
    45. [0.95428978 0.04571022]
    46. [0.79620831 0.20379169]
    47. [0.11122497 0.88877503]
    48. [0.14503562 0.85496438]
    49. [0.09769969 0.90230031]
    50. [0.1427527 0.8572473 ]
    51. [0.64864805 0.35135195]
    52. [0.77964905 0.22035095]
    53. [0.50532259 0.49467741]
    54. [0.88917687 0.11082313]
    55. [0.20508718 0.79491282]
    56. [0.22918407 0.77081593]]
    57. fpr: [0. 0. 0. 0.2 0.2 0.3 0.3 0.6 0.6 0.7 0.7 1. ]
    58. tpr: [0. 0.1 0.6 0.6 0.7 0.7 0.8 0.8 0.9 0.9 1. 1. ]
    59. auc: 0.82
    60. ****************************************************************************************************
    61. y_proba: [[0.81694936 0.18305064]
    62. [0.58068561 0.41931439]
    63. [0.95133392 0.04866608]
    64. [0.40420908 0.59579092]
    65. [0.3271581 0.6728419 ]
    66. [0.99027305 0.00972695]
    67. [0.64918216 0.35081784]
    68. [0.90200046 0.09799954]
    69. [0.63054898 0.36945102]
    70. [0.93316453 0.06683547]
    71. [0.53006938 0.46993062]
    72. [0.17861305 0.82138695]
    73. [0.006705 0.993295 ]
    74. [0.09477154 0.90522846]
    75. [0.56917531 0.43082469]
    76. [0.03227622 0.96772378]
    77. [0.22280499 0.77719501]
    78. [0.15966529 0.84033471]
    79. [0.02610573 0.97389427]
    80. [0.01608401 0.98391599]]
    81. fpr: [0. 0. 0. 0.2 0.2 1. ]
    82. tpr: [0. 0.1 0.8 0.8 1. 1. ]
    83. auc: 0.9600000000000001
    84. ****************************************************************************************************
    85. y_proba: [[0.73755142 0.26244858]
    86. [0.81486985 0.18513015]
    87. [0.98155993 0.01844007]
    88. [0.62469409 0.37530591]
    89. [0.86580681 0.13419319]
    90. [0.93865476 0.06134524]
    91. [0.76684129 0.23315871]
    92. [0.26828926 0.73171074]
    93. [0.95379293 0.04620707]
    94. [0.82872899 0.17127101]
    95. [0.0450968 0.9549032 ]
    96. [0.4752642 0.5247358 ]
    97. [0.38068224 0.61931776]
    98. [0.56844634 0.43155366]
    99. [0.49825931 0.50174069]
    100. [0.05526257 0.94473743]
    101. [0.04108483 0.95891517]
    102. [0.00417408 0.99582592]
    103. [0.09069155 0.90930845]
    104. [0.42708884 0.57291116]]
    105. fpr: [0. 0. 0. 0.1 0.1 1. ]
    106. tpr: [0. 0.1 0.5 0.5 1. 1. ]
    107. auc: 0.9500000000000001
    108. ****************************************************************************************************
    109. y_proba: [[0.89441894 0.10558106]
    110. [0.65744045 0.34255955]
    111. [0.67092317 0.32907683]
    112. [0.78029511 0.21970489]
    113. [0.69217484 0.30782516]
    114. [0.97861482 0.02138518]
    115. [0.711046 0.288954 ]
    116. [0.94908913 0.05091087]
    117. [0.62170149 0.37829851]
    118. [0.57082372 0.42917628]
    119. [0.59759391 0.40240609]
    120. [0.53269573 0.46730427]
    121. [0.08361238 0.91638762]
    122. [0.3546565 0.6453435 ]
    123. [0.13494363 0.86505637]
    124. [0.01205661 0.98794339]
    125. [0.04489417 0.95510583]
    126. [0.57049956 0.42950044]
    127. [0.3636283 0.6363717 ]
    128. [0.13165516 0.86834484]]
    129. fpr: [0. 0. 0. 0.1 0.1 1. ]
    130. tpr: [0. 0.1 0.9 0.9 1. 1. ]
    131. auc: 0.99
    132. ****************************************************************************************************
    133. y_proba: [[0.85161531 0.14838469]
    134. [0.9726683 0.0273317 ]
    135. [0.53251231 0.46748769]
    136. [0.72269431 0.27730569]
    137. [0.87414963 0.12585037]
    138. [0.79130481 0.20869519]
    139. [0.98550565 0.01449435]
    140. [0.56034861 0.43965139]
    141. [0.55647585 0.44352415]
    142. [0.72393126 0.27606874]
    143. [0.03734951 0.96265049]
    144. [0.16550755 0.83449245]
    145. [0.28703024 0.71296976]
    146. [0.1594562 0.8405438 ]
    147. [0.07379419 0.92620581]
    148. [0.48656743 0.51343257]
    149. [0.3818963 0.6181037 ]
    150. [0.23117614 0.76882386]
    151. [0.4644294 0.5355706 ]
    152. [0.46337177 0.53662823]]
    153. fpr: [0. 0. 0. 1.]
    154. tpr: [0. 0.1 1. 1. ]
    155. auc: 1.0
    156. ****************************************************************************************************
    157. 0.944
    158. '''

    线性插值

    1. x = np.linspace(0,10,30)
    2. y = np.sin(x)
    3. plt.scatter(x,y)

    1. x2 = np.linspace(0,10,100)
    2. # interp:线性插值
    3. # 让 x2,y2 之间的关系和 x,y之间的关系一样
    4. y2 = np.interp(x2,x,y)
    5. plt.scatter(x,y)
    6. plt.scatter(x2,y2,marker='*')

    计算平均AUC值,和平均ROC曲线

    • auc <= 0.5 : 模型很差(0.5 相当于随机猜测)
    • auc > 0.6 : 模型一般
    • auc > 0.7 : 模型还可以
    • auc > 0.8 : 模型较好
    • auc > 0.9 : 模型非常好

     

    1. # 算平均AUC值
    2. np.array(auc_list).mean()
    3. # 0.944
    4. # 相当于 x 轴
    5. fprs = np.linspace(0,1,101)
    6. tprs_list = []
    7. auc_list = []
    8. for train,test in skf.split(data4,target2):
    9. x_train = data4[train]
    10. y_train = target2[train]
    11. x_test = data4[test]
    12. y_test = target2[test]
    13. # LR
    14. lr = LogisticRegression()
    15. lr.fit(x_train,y_train)
    16. # 预测
    17. # y_pred = lr.predict(x_test)
    18. # 预测概率
    19. y_proba = lr.predict_proba(x_test)
    20. # roc
    21. fpr,tpr,_ = roc_curve(y_test,y_proba[:,1])
    22. auc_ = auc(fpr,tpr)
    23. auc_list.append(auc_)
    24. # 画图
    25. plt.plot(fpr,tpr,ls='--',label=f'auc:{np.round(auc_,2)}')
    26. # 线性插值
    27. # 让 fprs 与 tprs 的关系和 fpr 与 tpr 的关系一样
    28. tprs = np.interp(fprs,fpr,tpr)
    29. tprs_list.append(tprs)
    30. # 平均 tprs
    31. tprs_mean = np.array(tprs_list).mean(axis=0)
    32. auc_mean = np.array(auc_list).mean()
    33. # 画平均ROC图
    34. plt.plot(fprs,tprs_mean,label=f'auc_mean:{np.round(auc_mean,2)}')

  • 相关阅读:
    Android /android_vendor.32_arm64_armv8-a_shared/libtinyals a.so.abidiff报错
    【前端面试常问】MVC与MVVM
    (209)Verilog HDL:设计一个电路之Rule 90
    STM32F103C8/BT6 USART1 DMA发送失败
    【PAT】数据结构树和图月考复习1
    鸿蒙项目实战-月木学途:1.编写首页,包括搜索栏、轮播图、宫格
    elasticsearch 安装教程
    shardingsphere分库分表示例(逻辑表,真实表,绑定表,广播表,单表)
    面试十三、malloc 、calloc、realloc以及new的区别
    关于类的继承
  • 原文地址:https://blog.csdn.net/qq_52421831/article/details/127849837