• ML之shap:基于adult人口普查收入二分类预测数据集(预测年收入是否超过50k)利用Shap值对XGBoost模型实现可解释性案例之详细攻略


    ML之shap:基于adult人口普查收入二分类预测数据集(预测年收入是否超过50k)利用Shap值对XGBoost模型实现可解释性案例之详细攻略

    目录

    基于adult人口普查收入二分类预测数据集(预测年收入是否超过50k)利用Shap值对XGBoost模型实现可解释性案例

    1、定义数据集

    2、数据集预处理

    # 2.1、入模特征初步筛选

    # 2.2、目标特征二值化

    # 2.3、类别型特征编码数字化

    # 2.4、分离特征与标签

    # 2.5、数据集整体切分

    #3、模型训练与推理

    # 3.1、数据集切分

    # 3.2、模型建立并训练

    # 3.3、模型预测

    #4、模型特征重要性解释可视化

    #4.1、全局特征重要性可视化

    # T1、基于模型本身输出特征重要性

    # T2、利用Shap值解释XGBR模型

    #4.2、局部特征重要性可视化

    # (1)、单样本全特征条形图可视化

    # (2)、单转双特征全样本局部独立图散点图可视化

    # (3)、双特征全样本散点图可视化

    # 4.3、模型特征筛选

    # (1)、基于聚类的shap特征筛选可视化

    5、模型预测的可解释性(可主要分析误分类的样本)

    #  5.1、力图可视化分析:可视化单个或多个样本内各个特征贡献度并对比模型预测值——探究误分类样本

    (1)、单个样本力图可视化—对比预测

    (2)、多个样本力图可视化

    #  5.2、决策图可视化分析:模型如何做出决策

    # (1)、单个样本决策图可视化

    # (2)、多个样本决策图可视化


    基于adult人口普查收入二分类预测数据集(预测年收入是否超过50k)利用Shap值对XGBoost模型实现可解释性案例

    1、定义数据集

    dtypes_len: 15

    ageworkclassfnlwgteducationeducation_nummarital_statusoccupationrelationshipracesexcapital_gaincapital_losshours_per_weeknative_countrysalary
    39State-gov77516Bachelors13Never-marriedAdm-clericalNot-in-familyWhiteMale2174040United-States<=50K
    50Self-emp-not-inc83311Bachelors13Married-civ-spouseExec-managerialHusbandWhiteMale0013United-States<=50K
    38Private215646HS-grad9DivorcedHandlers-cleanersNot-in-familyWhiteMale0040United-States<=50K
    53Private23472111th7Married-civ-spouseHandlers-cleanersHusbandBlackMale0040United-States<=50K
    28Private338409Bachelors13Married-civ-spouseProf-specialtyWifeBlackFemale0040Cuba<=50K
    37Private284582Masters14Married-civ-spouseExec-managerialWifeWhiteFemale0040United-States<=50K
    49Private1601879th5Married-spouse-absentOther-serviceNot-in-familyBlackFemale0016Jamaica<=50K
    52Self-emp-not-inc209642HS-grad9Married-civ-spouseExec-managerialHusbandWhiteMale0045United-States>50K
    31Private45781Masters14Never-marriedProf-specialtyNot-in-familyWhiteFemale14084050United-States>50K
    42Private159449Bachelors13Married-civ-spouseExec-managerialHusbandWhiteMale5178040United-States>50K

    2、数据集预处理

    # 2.1、入模特征初步筛选

    df.columns 
     14

    # 2.2、目标特征二值化

    # 2.3、类别型特征编码数字化

    filt_dtypes_len: 13 [('age', 'float32'), ('workclass', 'category'), ('fnlwgt', 'float32'), ('education_Num', 'float32'), ('marital_status', 'category'), ('occupation', 'category'), ('relationship', 'category'), ('race', 'category'), ('sex', 'category'), ('capital_gain', 'float32'), ('capital_loss', 'float32'), ('hours_per_week', 'float32'), ('native_country', 'category')]
     

    # 2.4、分离特征与标签

    df_adult_display

    ageworkclasseducation_nummarital_statusoccupationrelationshipracesexcapital_gaincapital_losshours_per_weeknative_countrysalary
    039State-gov13Never-marriedAdm-clericalNot-in-familyWhiteMale2174040United-States0
    150Self-emp-not-inc13Married-civ-spouseExec-managerialHusbandWhiteMale0013United-States0
    238Private9DivorcedHandlers-cleanersNot-in-familyWhiteMale0040United-States0
    353Private7Married-civ-spouseHandlers-cleanersHusbandBlackMale0040United-States0
    428Private13Married-civ-spouseProf-specialtyWifeBlackFemale0040Cuba0
    537Private14Married-civ-spouseExec-managerialWifeWhiteFemale0040United-States0
    649Private5Married-spouse-absentOther-serviceNot-in-familyBlackFemale0016Jamaica0
    752Self-emp-not-inc9Married-civ-spouseExec-managerialHusbandWhiteMale0045United-States1
    831Private14Never-marriedProf-specialtyNot-in-familyWhiteFemale14084050United-States1
    942Private13Married-civ-spouseExec-managerialHusbandWhiteMale5178040United-States1

    df_adult

    ageworkclasseducation_nummarital_statusoccupationrelationshipracesexcapital_gaincapital_losshours_per_weeknative_countrysalary
    039713411412174040390
    150613240410013390
    23849061410040390
    35347260210040390
    428413210520004050
    537414245400040390
    64945381200016230
    75269240410045391
    83141441014014084050391
    942413240415178040391

    # 2.5、数据集整体切分

    df_len: 32561 ,train_test_index: 30933
    X.shape,y.shape: (30933, 12) (30933,)
    X_test.shape,y_test.shape: (1628, 12) (1628,)

    #3、模型训练与推理

    # 3.1、数据集切分

    # 3.2、模型建立并训练

    # 3.3、模型预测

    ageworkclasseducation_nummarital_statusoccupationrelationshipracesexcapital_gaincapital_losshours_per_weeknative_countryy_val_prediy_val
    1131129494132000603900
    12519334104312186140403911
    292252741341014100453900
    542822492704100403900
    2400327104112000403900
    4319454102404100403910
    2656443492604100403900
    472160013200410083901
    19518296921204100353900
    2501333452604100403900

    #4、模型特征重要性解释可视化

    #4.1、全局特征重要性可视化

    # T1、基于模型本身输出特征重要性

     XGBR_importance_dict: [('age', 130), ('capital_gain', 125), ('education_num', 86), ('capital_loss', 75), ('hours_per_week', 63), ('relationship', 59), ('marital_status', 52), ('occupation', 52), ('workclass', 20), ('sex', 13), ('native_country', 10), ('race', 6)]

    # T2、利用Shap值解释XGBR模型

    利用shap自带的函数实现特征贡献性可视化——特征重要性排序与上边类似,但并不相同

    # (1)、创建Explainer并计算SHAP值

    # T2.1、输出shap.Explanation对象

    # T2,2、输出numpy.array数组

    1. shap2exp.values.shape (30933, 12)
    2. [[ 0.31074238 -0.16607898 0.5617416 ... -0.04660619 -0.09465054
    3. 0.00530914]
    4. [ 0.34912622 -0.16633348 0.65308005 ... -0.06718991 -0.9804511
    5. 0.00515459]
    6. [ 0.21971266 0.02263742 -0.299867 ... -0.0583196 -0.09738331
    7. 0.00415599]
    8. ...
    9. [-0.48140627 0.07019287 -0.30844492 ... -0.04253047 -0.10924102
    10. 0.00649792]
    11. [ 0.39729887 -0.2313431 -0.45257783 ... -0.06502013 0.27416423
    12. 0.00587647]
    13. [ 0.27594262 0.03170239 0.78293955 ... -0.06743324 0.31613
    14. 0.00530914]]
    15. shap2array.shape (30933, 12)
    16. [[ 0.31074238 -0.16607898 0.5617416 ... -0.04660619 -0.09465054
    17. 0.00530914]
    18. [ 0.34912622 -0.16633348 0.65308005 ... -0.06718991 -0.9804511
    19. 0.00515459]
    20. [ 0.21971266 0.02263742 -0.299867 ... -0.0583196 -0.09738331
    21. 0.00415599]
    22. ...
    23. [-0.48140627 0.07019287 -0.30844492 ... -0.04253047 -0.10924102
    24. 0.00649792]
    25. [ 0.39729887 -0.2313431 -0.45257783 ... -0.06502013 0.27416423
    26. 0.00587647]
    27. [ 0.27594262 0.03170239 0.78293955 ... -0.06743324 0.31613
    28. 0.00530914]]
    29. shap2exp.values与shap2array,两个矩阵否相等: True

    # (2)、全样本各特征shap值条形图可视化

     # shap值高阶交互可视化

     

    # (3)、全样本各特征shap值蜂群图可视化

     

     

     

     

    # (4)、全局特征重要性排序散点图可视化

     

     

     

    #4.2、局部特征重要性可视化

    # (1)、单样本全特征条形图可视化

    前测试样本:0

    1. .values =
    2. array([ 0.31074238, -0.16607898, 0.5617416 , -0.58709425, -0.08897061,
    3. -0.6133537 , 0.01539118, 0.04758333, -0.3988452 , -0.04660619,
    4. -0.09465054, 0.00530914], dtype=float32)
    5. .base_values =
    6. -1.3270257
    7. .data =
    8. array([3.900e+01, 7.000e+00, 1.300e+01, 4.000e+00, 1.000e+00, 1.000e+00,
    9. 4.000e+00, 1.000e+00, 2.174e+03, 0.000e+00, 4.000e+01, 3.900e+01])

     

    前测试样本:1

    1. .values =
    2. array([ 0.34912622, -0.16633348, 0.65308005, 0.3069151 , 0.26878497,
    3. 0.5229906 , 0.01030679, 0.04531586, -0.15429462, -0.06718991,
    4. -0.9804511 , 0.00515459], dtype=float32)
    5. .base_values =
    6. -1.3270257
    7. .data =
    8. array([50., 6., 13., 2., 4., 0., 4., 1., 0., 0., 13., 39.])

     

    前测试样本:10

    1. .values =
    2. array([ 0.27578622, 0.02686635, -0.0699547 , 0.2820353 , 0.3097189 ,
    3. 0.55229187, -0.03686382, 0.05135565, -0.1607191 , -0.06321771,
    4. 0.38190693, 0.02023092], dtype=float32)
    5. .base_values =
    6. -1.3270257
    7. .data =
    8. array([37., 4., 10., 2., 4., 0., 2., 1., 0., 0., 80., 39.])

     

    前测试样本:20

    1. .values =
    2. array([ 0.31008577, 0.00316932, 1.3133987 , 0.16768128, 0.18239255,
    3. 0.6863757 , 0.00508371, 0.05159741, -0.15813455, -0.06736177,
    4. 0.31327826, 0.01936885], dtype=float32)
    5. .base_values =
    6. -1.3270257
    7. .data =
    8. array([40., 4., 16., 2., 10., 0., 4., 1., 0., 0., 60., 39.])

     

    # (2)、单转双特征全样本局部独立图散点图可视化

     

     

    # (3)、双特征全样本散点图可视化

     

    # 4.3、模型特征筛选

    # (1)、基于聚类的shap特征筛选可视化

     

    5、模型预测的可解释性(可主要分析误分类的样本)

    提供了预测的细节,侧重于解释单个预测是如何生成的。它可以帮助决策者信任模型,并且解释各个特征是如何影响模型单次的决策。

    #  5.1、力图可视化分析:可视化单个或多个样本内各个特征贡献度对比模型预测值——探究误分类样本

    提供了单一模型预测的可解释性,可用于误差分析,找到对特定实例预测的解释。如样例0所示:
    (1)、模型输出值:5.89;
    (2)、基值:base value即explainer.expected_value,即模型输出与训练数据的平均值;
    (3)、绘图箭头下方数字是此实例的特征值。如Age=39;
    (4)、红色则表示该特征的贡献是正数(将预测推高的特征)蓝色表示该特征的贡献是负数(将预测的特征)。长度表示影响力;箭头越长,特征对输出的影响(贡献)越大。通过 x 轴上刻度值可以看到影响的减少或增加量。

     

    (1)、单个样本力图可视化—对比预测

    输出当前测试样本:0

    1. mode_exp_value: -1.3270257
    2. <IPython.core.display.HTML object>
    3. 输出当前测试样本:0
    4. age 29.0
    5. workclass 4.0
    6. education_num 9.0
    7. marital_status 4.0
    8. occupation 1.0
    9. relationship 3.0
    10. race 2.0
    11. sex 0.0
    12. capital_gain 0.0
    13. capital_loss 0.0
    14. hours_per_week 60.0
    15. native_country 39.0
    16. y_val_predi 0.0
    17. y_val 0.0
    18. Name: 11311, dtype: float64
    19. 输出当前测试样本的真实label: 0
    20. 输出当前测试样本的的预测概率: 0

     

    输出当前测试样本:1

    1. 输出当前测试样本:1
    2. age 33.0
    3. workclass 4.0
    4. education_num 10.0
    5. marital_status 4.0
    6. occupation 3.0
    7. relationship 1.0
    8. race 2.0
    9. sex 1.0
    10. capital_gain 8614.0
    11. capital_loss 0.0
    12. hours_per_week 40.0
    13. native_country 39.0
    14. y_val_predi 1.0
    15. y_val 1.0
    16. Name: 12519, dtype: float64
    17. 输出当前测试样本的真实label: 1
    18. 输出当前测试样本的的预测概率: 1

     

    输出当前测试样本:5 

     

    1. 输出当前测试样本:5
    2. age 45.0
    3. workclass 4.0
    4. education_num 10.0
    5. marital_status 2.0
    6. occupation 4.0
    7. relationship 0.0
    8. race 4.0
    9. sex 1.0
    10. capital_gain 0.0
    11. capital_loss 0.0
    12. hours_per_week 40.0
    13. native_country 39.0
    14. y_val_predi 1.0
    15. y_val 0.0
    16. Name: 4319, dtype: float64
    17. 输出当前测试样本的真实label: 0
    18. 输出当前测试样本的的预测概率: 1

     

    输出当前测试样本:7 

    1. 输出当前测试样本:7
    2. age 60.0
    3. workclass 0.0
    4. education_num 13.0
    5. marital_status 2.0
    6. occupation 0.0
    7. relationship 0.0
    8. race 4.0
    9. sex 1.0
    10. capital_gain 0.0
    11. capital_loss 0.0
    12. hours_per_week 8.0
    13. native_country 39.0
    14. y_val_predi 0.0
    15. y_val 1.0
    16. Name: 4721, dtype: float64
    17. 输出当前测试样本的真实label: 1
    18. 输出当前测试样本的的预测概率: 0

     

     

    (2)、多个样本力图可视化

    # (2.1)、特征贡献度力图可视化,利用深红色深蓝色地图可视化前 5个预测解释,可以使用X数据集。

    # (2.2)、误分类力图可视化,肯定要用X_val数据集,因为涉及到模型预测。
    如果对多个样本进行解释,将上述形式旋转90度然后水平并排放置,得到力图的变体

     

     

    #  5.2、决策图可视化分析:模型如何做出决策

    # (1)、单个样本决策图可视化

    # (2)、多个样本决策图可视化

    # (2.1)、部分样本决策图可视化

    # (2.2)、误分类样本决策图可视化

  • 相关阅读:
    C++学习笔记13 - 浅拷贝和深拷贝
    线性代数
    第3章-线性方程组(3)
    [Power Query] 添加列
    Maven项目,进行编译,使用idea的 编译功能,就是正常的,但是在终端中执行 mvn clean compile 报错
    Springboot 阻止XSS攻击
    WPF开发随笔收录-心电图曲线绘制
    【Java】字节流、字符流、IO异常、属性集
    算法通过村第九关-二分(中序遍历)黄金笔记|二叉搜索树
    实际项目中最常用的设计模式
  • 原文地址:https://blog.csdn.net/qq_41185868/article/details/125631035