• [Python Machine Learning Projects] Project 1: Heart Disease Binary Classification


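    Predicting heart disease with machine learning

    Predict heart disease from a set of clinical attributes

    Notes:

    1. A new series begins! It consists of 2 projects. They are not difficult and are well suited to beginners.

    2. Since this is only the first project in the series, many details are not explored in depth. It serves as a demonstration, and the second project fills in the gaps.

    An overview of the overall approach follows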


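    1. Problem definition

    Given a patient's clinical measurements, can we predict whether they have heart disease?

    2. Data source

    https://archive.ics.uci.edu/ml/datasets/Heart+Disease

    3. Evaluation

    Target: reach 95% accuracy

    4. Features and label

    Data dictionary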

    1. age: age in years
    2. sex: sex (1 = male; 0 = female)
    3. cp: chest pain type
       • Value 0: typical angina
       • Value 1: atypical angina
       • Value 2: non-anginal pain
       • Value 3: asymptomatic
    4. trestbps: resting blood pressure (in mm Hg on admission to the hospital)
    5. chol: serum cholesterol in mg/dl
    6. fbs: fasting blood sugar > 120 mg/dl (1 = true; 0 = false)
    7. restecg: resting electrocardiographic results
       • Value 0: normal
       • Value 1: having ST-T wave abnormality (T wave inversions and/or ST elevation or depression of > 0.05 mV)
       • Value 2: showing probable or definite left ventricular hypertrophy by Estes' criteria
    8. thalach: maximum heart rate achieved
    9. exang: exercise induced angina (1 = yes; 0 = no)
    10. oldpeak: ST depression induced by exercise relative to rest
    11. slope: the slope of the peak exercise ST segment
       • Value 0: upsloping
       • Value 1: flat
       • Value 2: downsloping
    12. ca: number of major vessels (0-3) colored by fluoroscopy
    13. thal: 0 = normal; 1 = fixed defect; 2 = reversible defect
    14. target: 0 = no disease, 1 = disease

    0. Imports

    # EDA
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    from scipy import stats
    sns.set()
    # SimHei is a CJK font, set so that Chinese plot labels render correctly
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    %config InlineBackend.figure_format = 'svg'
    
    # scikit-learn models
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.ensemble import RandomForestClassifier
    
    # Model evaluation
    from sklearn.model_selection import train_test_split, cross_val_score
    from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
    from sklearn.metrics import confusion_matrix, classification_report
    from sklearn.metrics import precision_score, recall_score, f1_score
    # Note: plot_roc_curve was removed in scikit-learn 1.2; newer versions provide RocCurveDisplay instead
    from sklearn.metrics import plot_roc_curve
    

    Load the data

    hd_df = pd.read_csv('heart-disease.csv')
    hd_df.shape
    
    (303, 14)
    

    1. EDA

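    Learn more about the dataset and become its resident know-it-all

    1. What problem are we trying to solve?
    2. What data do we have, and how should it be handled?
    3. Are there missing values, and how should they be handled?
    4. Are there outliers, and how should they be handled?
    5. How can we get more information by creating derived features and by processing and selecting the existing ones?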
    hd_df.head()
    
       age  sex  cp  trestbps  chol  fbs  restecg  thalach  exang  oldpeak  slope  ca  thal  target
    0   63    1   3       145   233    1        0      150      0      2.3      0   0     1       1
    1   37    1   2       130   250    0        1      187      0      3.5      0   0     2       1
    2   41    0   1       130   204    0        0      172      0      1.4      2   0     2       1
    3   56    1   1       120   236    0        1      178      0      0.8      2   0     2       1
    4   57    0   0       120   354    0        1      163      1      0.6      2   0     2       1
    hd_df.tail()
    
         age  sex  cp  trestbps  chol  fbs  restecg  thalach  exang  oldpeak  slope  ca  thal  target
    298   57    0   0       140   241    0        1      123      1      0.2      1   0     3       0
    299   45    1   3       110   264    0        1      132      0      1.2      1   0     3       0
    300   68    1   0       144   193    1        1      141      0      3.4      1   2     3       0
    301   57    1   0       130   131    0        1      115      1      1.2      1   1     3       0
    302   57    0   1       130   236    0        0      174      0      0.0      1   1     2       0
    # Check the class distribution
    targets = hd_df['target'].value_counts()
    targets
    
    1    165
    0    138
    Name: target, dtype: int64
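With two fairly balanced classes, a useful reference point is the accuracy of always predicting the majority class. The sketch below only reuses the counts printed above; the variable names are illustrative and not part of the original notebook.

```python
# Class counts copied from the value_counts() output above
class_counts = {1: 165, 0: 138}

total = sum(class_counts.values())
majority_class = max(class_counts, key=class_counts.get)

# Accuracy of a "model" that always predicts the majority class
baseline_accuracy = class_counts[majority_class] / total

print(f"majority class: {majority_class}")
print(f"baseline accuracy: {baseline_accuracy:.4f}")
```

Any model trained later should beat this baseline by a clear margin to be worth keeping.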
    
    targets.plot(
        kind='bar', 
        color=['salmon', 'lightblue'],
        figsize=(10,6)
    )
    plt.xticks(rotation=0)
    plt.show()
    

    [Figure: bar chart of the target class counts]

    hd_df.info()
    
    <class 'pandas.core.frame.DataFrame'>
    RangeIndex: 303 entries, 0 to 302
    Data columns (total 14 columns):
     #   Column    Non-Null Count  Dtype  
    ---  ------    --------------  -----  
     0   age       303 non-null    int64  
     1   sex       303 non-null    int64  
     2   cp        303 non-null    int64  
     3   trestbps  303 non-null    int64  
     4   chol      303 non-null    int64  
     5   fbs       303 non-null    int64  
     6   restecg   303 non-null    int64  
     7   thalach   303 non-null    int64  
     8   exang     303 non-null    int64  
     9   oldpeak   303 non-null    float64
     10  slope     303 non-null    int64  
     11  ca        303 non-null    int64  
     12  thal      303 non-null    int64  
     13  target    303 non-null    int64  
    dtypes: float64(1), int64(13)
    memory usage: 33.3 KB
    
    # Check for missing values
    hd_df.isna().sum()
    
    age         0
    sex         0
    cp          0
    trestbps    0
    chol        0
    fbs         0
    restecg     0
    thalach     0
    exang       0
    oldpeak     0
    slope       0
    ca          0
    thal        0
    target      0
    dtype: int64
    
    # Descriptive statistics
    hd_df.describe([0.01, 0.25, 0.5, 0.75, 0.99]).T
    
              count        mean        std    min      1%    25%    50%    75%     99%    max
    age       303.0   54.366337   9.082101   29.0   35.00   47.5   55.0   61.0   71.00   77.0
    sex       303.0    0.683168   0.466011    0.0    0.00    0.0    1.0    1.0    1.00    1.0
    cp        303.0    0.966997   1.032052    0.0    0.00    0.0    1.0    2.0    3.00    3.0
    trestbps  303.0  131.623762  17.538143   94.0  100.00  120.0  130.0  140.0  180.00  200.0
    chol      303.0  246.264026  51.830751  126.0  149.00  211.0  240.0  274.5  406.74  564.0
    fbs       303.0    0.148515   0.356198    0.0    0.00    0.0    0.0    0.0    1.00    1.0
    restecg   303.0    0.528053   0.525860    0.0    0.00    0.0    1.0    1.0    1.98    2.0
    thalach   303.0  149.646865  22.905161   71.0   95.02  133.5  153.0  166.0  191.96  202.0
    exang     303.0    0.326733   0.469794    0.0    0.00    0.0    0.0    1.0    1.00    1.0
    oldpeak   303.0    1.039604   1.161075    0.0    0.00    0.0    0.8    1.6    4.20    6.2
    slope     303.0    1.399340   0.616226    0.0    0.00    1.0    1.0    2.0    2.00    2.0
    ca        303.0    0.729373   1.022606    0.0    0.00    0.0    0.0    1.0    4.00    4.0
    thal      303.0    2.313531   0.612277    0.0    1.00    2.0    2.0    3.0    3.00    3.0
    target    303.0    0.544554   0.498835    0.0    0.00    0.0    1.0    1.0    1.00    1.0

    Relationship between sex and the label

    hd_df['sex'].value_counts()
    
    1    207
    0     96
    Name: sex, dtype: int64
    
    # An enhanced crosstab helper: adds a `rate` column (second column / row total)
    def to_cross_tab(origin_df, index_name, col_name):
        df = pd.crosstab(origin_df[index_name], origin_df[col_name])
        df['rate'] = df.iloc[:,1] / (df.iloc[:,0] + df.iloc[:,1])
        return df
    
    # Index by sex, with target as columns, so that `rate` is the disease rate within each sex
    sex_target_df = to_cross_tab(hd_df, 'sex', 'target')
    sex_target_df
    
    target    0   1      rate
    sex                      
    0        24  72  0.750000
    1       114  93  0.449275
    # Helper for bar plots
    def to_plot(df, title, xlabel, ylabel, legend):
        df.plot(
            kind='bar',
            color=['lightblue', 'salmon'],
            figsize=(10, 6)
        )
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.xticks(rotation=0)
        plt.legend(legend)
        plt.show()
    
    to_plot(sex_target_df[[0,1]], 'Heart disease counts by sex', '0 = female, 1 = male', 'Count', ['No disease', 'Disease'])
    


    [Figure: bar chart of heart disease counts by sex]

    In this sample, the disease rate is clearly much higher among women (75%) than among men (about 45%)
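Whether a gap like this could be down to chance can be checked with a chi-square test of independence. This is a minimal sketch that is not part of the original notebook: it hard-codes the four counts from the crosstab in this section and uses `scipy.stats.chi2_contingency`.

```python
import numpy as np
from scipy.stats import chi2_contingency

# Rows: sex (0 = female, 1 = male); columns: target (0 = no disease, 1 = disease).
# The four counts are taken from the crosstab output in this section.
observed = np.array([[24, 72],
                     [114, 93]])

# Disease rate per sex: diseased / row total
rates = observed[:, 1] / observed.sum(axis=1)

# Chi-square test of independence
# (Yates' continuity correction is applied by default for 2x2 tables)
chi2, p_value, dof, expected = chi2_contingency(observed)

print("disease rate by sex:", rates.round(3))
print(f"chi2 = {chi2:.2f}, dof = {dof}, p = {p_value:.2e}")
```

A very small p-value says that sex and the label are associated in this sample. It says nothing about causes, and nothing about populations sampled differently from this one.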


    Age versus maximum heart rate in the disease and no-disease groups

    plt.figure(figsize=(10,6))
    
    # Patients with heart disease
    plt.scatter(hd_df['age'][hd_df['target']==1],
                hd_df['thalach'][hd_df['target']==1],
                c='salmon'
    )
    
    # Patients without heart disease
    plt.scatter(hd_df['age'][hd_df['target']==0],
                hd_df['thalach'][hd_df['target']==0],
                c='lightblue'
    )
    
    # Labels
    plt.title('Age vs. maximum heart rate, split by heart disease status')
    plt.xlabel('Age')
    plt.ylabel('Maximum heart rate')
    plt.legend(['Disease', 'No disease'])
    
    plt.show()
    


    [Figure: scatter plot of age vs. maximum heart rate, colored by disease status]

    # Age distribution
    hd_df['age'].hist()
    


    [Figure: histogram of age]

    # Normality test
    stats.normaltest(hd_df['age'])
    
    NormaltestResult(statistic=8.74798581312778, pvalue=0.012600826063683705)
    

    With p ≈ 0.013 < 0.05, the test rejects normality at the 5% level, so age is at best roughly normal
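`scipy.stats.normaltest` tests the null hypothesis that the sample comes from a normal distribution, so a small p-value counts as evidence against normality, and the conclusion depends on the chosen significance level. The helper `rejects_normality` below is a hypothetical name introduced only for this illustration, and the skewed sample is synthetic.

```python
import numpy as np
from scipy import stats

def rejects_normality(p_value, alpha=0.05):
    """Return True when the normality null hypothesis is rejected at level alpha."""
    return p_value < alpha

# p-value reported for the age column above
age_p_value = 0.012600826063683705
print("age rejects normality at 5%:", rejects_normality(age_p_value))
print("age rejects normality at 1%:", rejects_normality(age_p_value, alpha=0.01))

# Synthetic, clearly skewed sample for comparison (not from the dataset)
rng = np.random.default_rng(0)
skewed = rng.exponential(scale=1.0, size=500)
skewed_result = stats.normaltest(skewed)
print("skewed sample p-value:", skewed_result.pvalue)
```

For the age column the verdict flips between the 5% and 1% levels. That is a reason to look at the histogram as well, rather than rely on the test alone.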


    Relationship between chest pain type and the label

    3. cp: chest pain type
       • Value 0: typical angina
       • Value 1: atypical angina
       • Value 2: non-anginal pain
       • Value 3: asymptomatic
    cp_target_df = to_cross_tab(hd_df, 'cp', 'target')
    cp_target_df
    
    target    0   1      rate
    cp                       
    0       104  39  0.272727
    1         9  41  0.820000
    2        18  69  0.793103
    3         7  16  0.695652
    to_plot(cp_target_df[[0,1]], 'Heart disease counts by chest pain type', 'Chest pain type', 'Count', ['No disease', 'Disease'])
    


    [Figure: bar chart of heart disease counts by chest pain type]

    # Correlation matrix
    corr_matrix = hd_df.corr()
    corr_matrix
    
                    age        sex         cp   trestbps       chol        fbs    restecg    thalach      exang    oldpeak      slope         ca       thal     target
    age        1.000000  -0.098447  -0.068653   0.279351   0.213678   0.121308  -0.116211  -0.398522   0.096801   0.210013  -0.168814   0.276326   0.068001  -0.225439
    sex       -0.098447   1.000000  -0.049353  -0.056769  -0.197912   0.045032  -0.058196  -0.044020   0.141664   0.096093  -0.030711   0.118261   0.210041  -0.280937
    cp        -0.068653  -0.049353   1.000000   0.047608  -0.076904   0.094444   0.044421   0.295762  -0.394280  -0.149230   0.119717  -0.181053  -0.161736   0.433798
    trestbps   0.279351  -0.056769   0.047608   1.000000   0.123174   0.177531  -0.114103  -0.046698   0.067616   0.193216  -0.121475   0.101389   0.062210  -0.144931
    chol       0.213678  -0.197912  -0.076904   0.123174   1.000000   0.013294  -0.151040  -0.009940   0.067023   0.053952  -0.004038   0.070511   0.098803  -0.085239
    fbs        0.121308   0.045032   0.094444   0.177531   0.013294   1.000000  -0.084189  -0.008567   0.025665   0.005747  -0.059894   0.137979  -0.032019  -0.028046
    restecg   -0.116211  -0.058196   0.044421  -0.114103  -0.151040  -0.084189   1.000000   0.044123  -0.070733  -0.058770   0.093045  -0.072042  -0.011981   0.137230
    thalach   -0.398522  -0.044020   0.295762  -0.046698  -0.009940  -0.008567   0.044123   1.000000  -0.378812  -0.344187   0.386784  -0.213177  -0.096439   0.421741
    exang      0.096801   0.141664  -0.394280   0.067616   0.067023   0.025665  -0.070733  -0.378812   1.000000   0.288223  -0.257748   0.115739   0.206754  -0.436757
    oldpeak    0.210013   0.096093  -0.149230   0.193216   0.053952   0.005747  -0.058770  -0.344187   0.288223   1.000000  -0.577537   0.222682   0.210244  -0.430696
    slope     -0.168814  -0.030711   0.119717  -0.121475  -0.004038  -0.059894   0.093045   0.386784  -0.257748  -0.577537   1.000000  -0.080155  -0.104764   0.345877
    ca         0.276326   0.118261  -0.181053   0.101389   0.070511   0.137979  -0.072042  -0.213177   0.115739   0.222682  -0.080155   1.000000   0.151832  -0.391724
    thal       0.068001   0.210041  -0.161736   0.062210   0.098803  -0.032019  -0.011981  -0.096439   0.206754   0.210244  -0.104764   0.151832   1.000000  -0.344029
    target    -0.225439  -0.280937   0.433798  -0.144931  -0.085239  -0.028046   0.137230   0.421741  -0.436757  -0.430696   0.345877  -0.391724  -0.344029   1.000000
    plt.figure(figsize=(14, 10))
    sns.heatmap(
        corr_matrix, 
        vmin=-1, 
        annot=True, 
        linewidth=5, 
        fmt='.2f', 
        cmap='YlGnBu'
    )
    plt.show()
    


    [Figure: heatmap of the correlation matrix]

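    The correlations look reasonable: most features show some correlation with the label, and no pair of features has a correlation above 0.8 that would call for dropping one of them. To study correlation properly, though, categorical and continuous variables should be handled separately, and the continuous ones should also be tested for normality.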
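The "no pair above 0.8" check can be done programmatically instead of by reading the heatmap. The sketch below is one possible way to do it, shown on toy data rather than the heart disease dataset; `high_corr_pairs` is a hypothetical helper name and the 0.8 threshold simply mirrors the one mentioned above.

```python
import numpy as np
import pandas as pd

def high_corr_pairs(df, threshold=0.8):
    """List column pairs whose absolute Pearson correlation exceeds threshold."""
    corr = df.corr().abs()
    # Keep only the upper triangle so each pair appears once and the diagonal is skipped
    mask = np.triu(np.ones(corr.shape, dtype=bool), k=1)
    pairs = corr.where(mask).stack()
    # Comparisons with NaN are False, so masked-out cells never pass the filter
    return pairs[pairs > threshold].sort_values(ascending=False)

# Toy data: b is almost a copy of a, c is independent noise
rng = np.random.default_rng(0)
a = rng.normal(size=200)
toy = pd.DataFrame({
    "a": a,
    "b": 2 * a + rng.normal(scale=0.1, size=200),
    "c": rng.normal(size=200),
})

result = high_corr_pairs(toy)
print(result)
```

Applied to the feature columns of `hd_df`, an empty result would confirm the reading of the heatmap.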


    3. Modeling

    hd_df.head()
    
       age  sex  cp  trestbps  chol  fbs  restecg  thalach  exang  oldpeak  slope  ca  thal  target
    0   63    1   3       145   233    1        0      150      0      2.3      0   0     1       1
    1   37    1   2       130   250    0        1      187      0      3.5      0   0     2       1
    2   41    0   1       130   204    0        0      172      0      1.4      2   0     2       1
    3   56    1   1       120   236    0        1      178      0      0.8      2   0     2       1
    4   57    0   0       120   354    0        1      163      1      0.6      2   0     2       1
    X = hd_df.drop(columns=['target'])
    y = hd_df['target']
    
    # Set the random seed so that others can reproduce the experiment
    np.random.seed(13)
    
    # Split into training and test sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
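The global seed works here because `train_test_split` falls back on NumPy's global random state when no `random_state` is given. An alternative, sketched below on a synthetic stand-in with the same size and class balance as the dataset, is to pass `random_state` directly and add `stratify` so both splits keep the class ratio. This is a variation for illustration, not what the original notebook does, and it would produce a different split from the one used in the rest of this post.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in: 303 rows, 165 positives and 138 negatives
X_demo = np.arange(303).reshape(-1, 1)
y_demo = np.array([1] * 165 + [0] * 138)

# random_state pins the shuffle for this call only;
# stratify keeps the class ratio (almost) identical in both splits
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=13, stratify=y_demo
)

print(X_tr.shape, X_te.shape)
print("positive rate (train / test):", round(y_tr.mean(), 3), round(y_te.mean(), 3))
```

With 303 rows and `test_size=0.2`, the test split holds 61 rows, which matches the `Length: 61` shown for `y_test` later on.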
    

    Try logistic regression, KNN, and random forest in turn

    # Dictionary of candidate models
    models = {
        'lr': LogisticRegression(),
        'knn': KNeighborsClassifier(),
        'rf': RandomForestClassifier()
    }
    
    # A simple helper for an exploratory fit and score of each model
    def fit_and_score(models, X_train, X_test, y_train, y_test):
        np.random.seed(13)
        model_score = {}
        for name, model in models.items():
            model.fit(X_train, y_train)
            model_score[name] = model.score(X_test, y_test)
        return model_score
    
    model_scores = fit_and_score(models, X_train, X_test, y_train, y_test)
    model_scores
    
     {'lr': 0.8360655737704918, 'knn': 0.639344262295082, 'rf': 0.819672131147541}
    

    Model comparison

    model_compare = pd.DataFrame(model_scores, index=['accuracy'])
    model_compare.T.plot(kind='bar')
    plt.xticks(rotation=0)
    plt.show()
    

    [Figure: bar chart comparing the test accuracy of the three models]

    What next?

    • Hyperparameter tuning
    • Feature importance
    • Confusion matrix
    • Cross-validation
    • Precision
    • Recall
    • F1 score
    • Classification report
    • ROC
    • AUC
    # Tune KNN by hand (pretending GridSearchCV and RandomizedSearchCV don't exist)
    train_scores = []
    test_scores = []
    
    neighbors = range(1, 21)
    
    knn = KNeighborsClassifier()
    for i in neighbors:
        knn.set_params(n_neighbors=i)
        knn.fit(X_train, y_train)
        train_scores.append(knn.score(X_train, y_train))
        test_scores.append(knn.score(X_test, y_test))
    
    train_scores
    
    [1.0,
     0.8016528925619835,
     0.8057851239669421,
     0.7603305785123967,
     0.768595041322314,
     0.7355371900826446,
     0.7396694214876033,
     0.71900826446281,
     0.7024793388429752,
     0.6900826446280992,
     0.7107438016528925,
     0.6859504132231405,
     0.7024793388429752,
     0.6776859504132231,
     0.6942148760330579,
     0.6859504132231405,
     0.6694214876033058,
     0.6859504132231405,
     0.7024793388429752,
     0.7066115702479339]
    
    test_scores
    
    [0.6065573770491803,
     0.4426229508196721,
     0.5737704918032787,
     0.5409836065573771,
     0.639344262295082,
     0.6557377049180327,
     0.6065573770491803,
     0.6721311475409836,
     0.6557377049180327,
     0.6557377049180327,
     0.6885245901639344,
     0.6885245901639344,
     0.6885245901639344,
     0.7377049180327869,
     0.7213114754098361,
     0.7213114754098361,
     0.7213114754098361,
     0.7049180327868853,
     0.7377049180327869,
     0.7377049180327869]
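Rather than scanning the list by eye, the best `n_neighbors` can be read off directly. The sketch below copies the test scores printed above into a list; the names `knn_test_scores`, `k_values` and `best_ks` are illustrative.

```python
# Test-set accuracies copied from the output above, for n_neighbors = 1..20
knn_test_scores = [
    0.6065573770491803, 0.4426229508196721, 0.5737704918032787, 0.5409836065573771,
    0.639344262295082, 0.6557377049180327, 0.6065573770491803, 0.6721311475409836,
    0.6557377049180327, 0.6557377049180327, 0.6885245901639344, 0.6885245901639344,
    0.6885245901639344, 0.7377049180327869, 0.7213114754098361, 0.7213114754098361,
    0.7213114754098361, 0.7049180327868853, 0.7377049180327869, 0.7377049180327869,
]
k_values = list(range(1, 21))

best_score = max(knn_test_scores)
# Every k that ties for the best score
best_ks = [k for k, s in zip(k_values, knn_test_scores) if s == best_score]

print(f"best test accuracy: {best_score:.4f}")
print("achieved at n_neighbors =", best_ks)
```

Choosing k on the test set, as this loop does, makes the reported score optimistic. A validation split or cross-validation would be the more careful choice.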
    
    plt.plot(neighbors, train_scores, label='Train score')
    plt.plot(neighbors, test_scores, label='Test score')
    plt.xticks(range(1,21,1))
    plt.xlabel('n_neighbors')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.show()
    


    [Figure: train and test accuracy against n_neighbors]

    Even the best KNN score (about 73.8% on the test set) falls short of 80% accuracy, so KNN is dropped
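One possible reason for the weak KNN scores, which the original notebook does not explore, is that KNN is distance-based and the features here sit on very different scales (for example `chol` in the hundreds versus 0/1 flags). The sketch below uses synthetic data, not the heart disease dataset, to show how scaling can be put in front of KNN with a pipeline. Whether it helps on the real data is not verified here.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic data; one feature is blown up to mimic mixed units
X_syn, y_syn = make_classification(
    n_samples=300, n_features=6, n_informative=4, random_state=0
)
X_syn[:, 0] *= 1000

X_a, X_b, y_a, y_b = train_test_split(X_syn, y_syn, test_size=0.2, random_state=0)

raw_knn = KNeighborsClassifier().fit(X_a, y_a)
# The scaler is fit on the training split only, inside the pipeline
scaled_knn = make_pipeline(StandardScaler(), KNeighborsClassifier()).fit(X_a, y_a)

raw_score = raw_knn.score(X_b, y_b)
scaled_score = scaled_knn.score(X_b, y_b)
print("raw features   :", raw_score)
print("scaled features:", scaled_score)
```

Because the scaler lives inside the pipeline, the same object can be handed to `cross_val_score` or a search object without leaking test-fold statistics into training.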


    Tuning with RandomizedSearchCV

    # Logistic regression
    # The main goal is to find the best C, so the other parameters are left alone. np.logspace deliberately spreads the C values widely, since there is no prior idea of where the optimum lies
    log_reg_grid = {
        'C':np.logspace(-4, 4, 20),
        'solver': ['liblinear']
    }
    
    # Random forest
    rf_grid = {
        'n_estimators': np.arange(10, 1000, 50),
        'max_depth': [None, 3, 5, 10],
        'min_samples_split': np.arange(2, 20, 2),
        'min_samples_leaf': np.arange(1, 20, 2)
    }
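It helps to know how large each search space is before deciding on `n_iter`. The sketch below rebuilds the two grids from the cell above and counts their combinations with scikit-learn's `ParameterGrid`; the names `n_log_reg` and `n_rf` are illustrative.

```python
import numpy as np
from sklearn.model_selection import ParameterGrid

# Same grids as in the cell above
log_reg_grid = {
    'C': np.logspace(-4, 4, 20),
    'solver': ['liblinear'],
}
rf_grid = {
    'n_estimators': np.arange(10, 1000, 50),
    'max_depth': [None, 3, 5, 10],
    'min_samples_split': np.arange(2, 20, 2),
    'min_samples_leaf': np.arange(1, 20, 2),
}

n_log_reg = len(ParameterGrid(log_reg_grid))
n_rf = len(ParameterGrid(rf_grid))

print("logistic regression combinations:", n_log_reg)
print("random forest combinations:", n_rf)
```

With `n_iter=20`, the logistic regression search visits every one of its 20 combinations and so amounts to a full grid search. The random forest search samples only 20 of 7200.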
    
    np.random.seed(13)
    
    # Instantiate the RandomizedSearchCV object
    rs_log_reg = RandomizedSearchCV(
        LogisticRegression(),
        param_distributions=log_reg_grid,
        cv=5,
        n_iter=20,
        verbose=True
    )
    # Fit
    rs_log_reg.fit(X_train, y_train)
    
    Fitting 5 folds for each of 20 candidates, totalling 100 fits
    
    RandomizedSearchCV(cv=5, estimator=LogisticRegression(), n_iter=20,
    
                   param_distributions={'C': array([1.00000000e-04, 2.63665090e-04, 6.95192796e-04, 1.83298071e-03,
       4.83293024e-03, 1.27427499e-02, 3.35981829e-02, 8.85866790e-02,
       2.33572147e-01, 6.15848211e-01, 1.62377674e+00, 4.28133240e+00,
       1.12883789e+01, 2.97635144e+01, 7.84759970e+01, 2.06913808e+02,
       5.45559478e+02, 1.43844989e+03, 3.79269019e+03, 1.00000000e+04]),
                                        'solver': ['liblinear']},
                   verbose=True)
    rs_log_reg.best_params_
    
    {'solver': 'liblinear', 'C': 1.623776739188721}
    
    rs_log_reg.score(X_test, y_test)
    
    0.819672131147541
    

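    A step backwards, which is hard to take: the test score dropped from 0.836 to 0.820. Since this is only the first project and the tuning is shown purely for demonstration, I'll let it go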
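It is worth putting the size of that change in perspective. The test split has only 61 rows, so accuracy moves in steps of 1/61. The sketch below converts the two scores back into counts of correct predictions and adds a rough binomial standard error, a textbook approximation rather than something the original notebook computes.

```python
import math

n_test = 61                        # size of the test split (20% of 303 rows)
score_before = 0.8360655737704918  # default LogisticRegression
score_after = 0.819672131147541    # after RandomizedSearchCV

correct_before = round(score_before * n_test)
correct_after = round(score_after * n_test)
print("correct predictions:", correct_before, "->", correct_after)

# Rough binomial standard error of an accuracy estimate on n_test samples
std_err = math.sqrt(score_after * (1 - score_after) / n_test)
print(f"approximate standard error: {std_err:.3f}")
```

The two scores differ by a single test sample, well inside the noise of a 61-row test set, so neither direction of change should be over-interpreted.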

    np.random.seed(13)
    
    # Instantiate the RandomizedSearchCV object
    rs_rf = RandomizedSearchCV(
        RandomForestClassifier(),
        param_distributions=rf_grid,
        cv=5,
        n_iter=20,
        verbose=True
    )
    # Fit
    rs_rf.fit(X_train, y_train)
    
    Fitting 5 folds for each of 20 candidates, totalling 100 fits
    
    RandomizedSearchCV(cv=5, estimator=RandomForestClassifier(), n_iter=20,
    
                   param_distributions={'max_depth': [None, 3, 5, 10],
                                        'min_samples_leaf': array([ 1,  3,  5,  7,  9, 11, 13, 15, 17, 19]),
                                        'min_samples_split': array([ 2,  4,  6,  8, 10, 12, 14, 16, 18]),
                                        'n_estimators': array([ 10,  60, 110, 160, 210, 260, 310, 360, 410, 460, 510, 560, 610,
       660, 710, 760, 810, 860, 910, 960])},
                   verbose=True)
rs_rf.best_params_
{'n_estimators': 310,
 'min_samples_split': 16,
 'min_samples_leaf': 9,
 'max_depth': None}
rs_rf.score(X_test, y_test)
0.8360655737704918

A slight improvement over the untuned random forest (0.820 to 0.836 on the test set)


Tuning with GridSearchCV

This time, search over a few more parameters

log_reg_grid = {
    'C':np.logspace(-4, 4, 30),
    'solver': ['liblinear', 'sag', 'saga', 'newton-cg', 'lbfgs'],
    'penalty': ['l1', 'l2']
}

# Instantiate the GridSearchCV object
# Note: 'sag', 'newton-cg' and 'lbfgs' do not support the 'l1' penalty; those fits fail and are scored as NaN
gs_log_reg = GridSearchCV(
    LogisticRegression(),
    param_grid=log_reg_grid,
    cv=5,
    verbose=True
)
# Fit
gs_log_reg.fit(X_train, y_train)
GridSearchCV(cv=5, estimator=LogisticRegression(),
         param_grid={'C': array([1.00000000e-04, 1.88739182e-04, 3.56224789e-04, 6.72335754e-04,
   1.26896100e-03, 2.39502662e-03, 4.52035366e-03, 8.53167852e-03,
   1.61026203e-02, 3.03919538e-02, 5.73615251e-02, 1.08263673e-01,
   2.04335972e-01, 3.85662042e-01, 7.27895384e-01, 1.37382380e+00,
   2.59294380e+00, 4.89390092e+00, 9.23670857e+00, 1.74332882e+01,
   3.29034456e+01, 6.21016942e+01, 1.17210230e+02, 2.21221629e+02,
   4.17531894e+02, 7.88046282e+02, 1.48735211e+03, 2.80721620e+03,
   5.29831691e+03, 1.00000000e+04]),
                     'penalty': ['l1', 'l2'],
                     'solver': ['liblinear', 'sag', 'saga', 'newton-cg',
                                'lbfgs']},
         verbose=True)
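The grid above pairs every solver with every penalty, but in scikit-learn's `LogisticRegression` the `'l1'` penalty is only supported by `'liblinear'` and `'saga'`. One way to avoid wasted fits, sketched below as an alternative to the original cell, is to pass `param_grid` as a list of dictionaries so that only valid pairs are generated. The names `full_grid` and `valid_grid` are illustrative.

```python
import numpy as np
from sklearn.model_selection import ParameterGrid

C_values = np.logspace(-4, 4, 30)

# Cross product as written above: every solver paired with every penalty
full_grid = {
    'C': C_values,
    'solver': ['liblinear', 'sag', 'saga', 'newton-cg', 'lbfgs'],
    'penalty': ['l1', 'l2'],
}

# Only the pairs scikit-learn accepts: l1 needs liblinear or saga
valid_grid = [
    {'C': C_values, 'solver': ['liblinear', 'saga'], 'penalty': ['l1', 'l2']},
    {'C': C_values, 'solver': ['sag', 'newton-cg', 'lbfgs'], 'penalty': ['l2']},
]

n_full = len(ParameterGrid(full_grid))
n_valid = len(ParameterGrid(valid_grid))
print("candidates in the full grid :", n_full)
print("candidates in the valid grid:", n_valid)
```

`GridSearchCV` accepts the list form directly through `param_grid=valid_grid`, which removes the 90 candidates that can never fit.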
gs_log_reg.best_params_
{'C': 221.22162910704503, 'penalty': 'l2', 'solver': 'lbfgs'}

Same score as before…

gs_log_reg.score(X_test, y_test)
0.819672131147541

4. Evaluation

y_pred = gs_log_reg.predict(X_test)
y_pred
array([0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
       1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1,
       0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1], dtype=int64)
y_test
203    0
30     1
58     1
90     1
119    1
      ..
249    0
135    1
41     1
67     1
148    1
Name: target, Length: 61, dtype: int64
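# Note: plot_roc_curve was removed in scikit-learn 1.2; newer versions provide RocCurveDisplay.from_estimator instead
plot_roc_curve(gs_log_reg, X_test, y_test)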


[Figure: ROC curve of the tuned logistic regression model]
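The ROC curve is built from predicted scores, not hard 0/1 labels, and the area under it (AUC) can be computed as a single number. The sketch below shows the mechanics on synthetic data with a plain logistic regression. It is not the tuned model from this post, and the resulting number says nothing about the heart disease data.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split

# Synthetic binary data (not the heart disease dataset)
X_syn, y_syn = make_classification(n_samples=300, n_features=8, random_state=0)
X_a, X_b, y_a, y_b = train_test_split(X_syn, y_syn, test_size=0.2, random_state=0)

model = LogisticRegression(max_iter=1000).fit(X_a, y_a)

# ROC analysis needs scores: use the predicted probability of class 1
proba = model.predict_proba(X_b)[:, 1]
fpr, tpr, thresholds = roc_curve(y_b, proba)
auc = roc_auc_score(y_b, proba)

print(f"AUC = {auc:.3f}")
```

For the model in this post, the same two calls would take `y_test` and `gs_log_reg.predict_proba(X_test)[:, 1]`.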

y_pred==1
array([False,  True,  True,  True,  True, False, False, False,  True,
       False,  True,  True, False, False,  True,  True,  True, False,
       False, False, False, False,  True, False,  True,  True,  True,
        True, False, False,  True, False, False,  True, False, False,
        True, False,  True,  True,  True, False, False,  True, False,
        True,  True, False,  True,  True,  True, False,  True,  True,
        True, False, False,  True,  True,  True,  True])
# Confusion matrix
def to_confusion_matrix(y_test, y_pred):
    return pd.DataFrame(
        data=confusion_matrix(y_test, y_pred), 
        index=pd.MultiIndex.from_product([['y_test'], [0, 1]]),
        columns=pd.MultiIndex.from_product([['y_pred'], [0, 1]])
    )
cf_matrix = to_confusion_matrix(y_test, y_pred)
cf_matrix
         y_pred    
              0   1
y_test 0     21   5
       1      6  29
# Classification report
print(classification_report(y_test, y_pred))
              precision    recall  f1-score   support

           0       0.78      0.81      0.79        26
           1       0.85      0.83      0.84        35

    accuracy                           0.82        61
   macro avg       0.82      0.82      0.82        61
weighted avg       0.82      0.82      0.82        61
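The class-1 row of the report can be reproduced by hand from the four confusion-matrix counts, which is a good way to check that the definitions are understood. A minimal sketch, with the counts copied from the confusion matrix above:

```python
# Counts from the confusion matrix above (rows: true label, columns: predicted label)
tn, fp = 21, 5
fn, tp = 6, 29

precision = tp / (tp + fp)   # of those predicted sick, how many are sick
recall = tp / (tp + fn)      # of those who are sick, how many were found
f1 = 2 * precision * recall / (precision + recall)
accuracy = (tp + tn) / (tp + tn + fp + fn)

print(f"precision={precision:.2f} recall={recall:.2f} f1={f1:.2f} accuracy={accuracy:.2f}")
```

In a medical setting the 6 false negatives (sick patients predicted healthy) are usually the costlier error, so recall deserves at least as much attention as accuracy.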

Evaluating the model with cross-validation

Use cross-validation to compute accuracy, precision, recall, and F1

gs_log_reg.best_params_
{'C': 221.22162910704503, 'penalty': 'l2', 'solver': 'lbfgs'}
# Re-instantiate the logistic regression model with the best parameters
clf = LogisticRegression(
    C=221.22162910704503, 
    penalty='l2', 
    solver='lbfgs'
)
# Cross-validated accuracy
cv_acc = cross_val_score(
    clf, 
    X, 
    y, 
    cv=5,
    scoring='accuracy'
)
cv_acc
array([0.81967213, 0.83606557, 0.85245902, 0.83333333, 0.75      ])
cv_acc = np.mean(cv_acc)
cv_acc
0.8183060109289617
# Cross-validated precision
cv_precision = cross_val_score(
    clf, 
    X, 
    y, 
    cv=5,
    scoring='precision'
)
cv_precision = np.mean(cv_precision)
cv_precision
0.8088942275474784
# Cross-validated recall
cv_recall = cross_val_score(
    clf, 
    X, 
    y, 
    cv=5,
    scoring='recall'
)
cv_recall = np.mean(cv_recall)
cv_recall
0.8787878787878787
# Cross-validated F1
cv_f1 = cross_val_score(
    clf, 
    X, 
    y, 
    cv=5,
    scoring='f1'
)
cv_f1 = np.mean(cv_f1)
cv_f1
0.8413377274453797
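The four `cross_val_score` calls above refit the model once per metric. scikit-learn's `cross_validate` accepts a list of scorers and evaluates all of them in a single pass. The sketch below shows the pattern on synthetic data, not the heart disease dataset; the names `metrics` and `mean_scores` are illustrative.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate

# Synthetic binary data (not the heart disease dataset)
X_syn, y_syn = make_classification(n_samples=300, n_features=8, random_state=0)

metrics = ['accuracy', 'precision', 'recall', 'f1']
cv_results = cross_validate(
    LogisticRegression(max_iter=1000),
    X_syn,
    y_syn,
    cv=5,
    scoring=metrics,
)

# One mean per metric; result keys are prefixed with 'test_'
mean_scores = {m: cv_results[f'test_{m}'].mean() for m in metrics}
print(mean_scores)
```

The resulting dictionary can be passed straight to `pd.DataFrame(mean_scores, index=[0])` for the same kind of bar chart as below.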
# Visualization
cv_metrics = pd.DataFrame(
    {'accuracy': cv_acc,
     'precision': cv_precision,
     'recall': cv_recall,
     'f1': cv_f1
    },
    index=[0]
)
cv_metrics.T.plot(
    kind='bar',
    legend=False
)
plt.xticks(rotation=0)
plt.show()


[Figure: bar chart of cross-validated accuracy, precision, recall, and F1]

5. Feature importance

clf.fit(X_train, y_train)
LogisticRegression(C=221.22162910704503)
clf.coef_
array([[ 0.00513208, -1.43253864,  0.78004753, -0.01083726, -0.0019836 ,
         0.0976912 ,  0.71562367,  0.03049414, -0.80027663, -0.44530236,
         0.53599288, -0.66841624, -1.15804589]])
# Pair each coefficient with its feature name (X excludes the target column)
feature_dict = dict(zip(X.columns, clf.coef_[0]))
feature_dict
{'age': 0.005132076982516595,
 'sex': -1.4325386407347098,
 'cp': 0.7800475335340353,
 'trestbps': -0.010837256399792251,
 'chol': -0.001983600334944071,
 'fbs': 0.09769119644464817,
 'restecg': 0.7156236671955836,
 'thalach': 0.030494138473504826,
 'exang': -0.8002766264626233,
 'oldpeak': -0.44530236148020047,
 'slope': 0.5359928831085665,
 'ca': -0.6684162375711792,
 'thal': -1.158045891987526}
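Logistic regression coefficients are on the log-odds scale, so exponentiating them gives odds ratios, which are often easier to interpret. The sketch below copies the coefficients printed above, rounded to four decimals. Ranking by absolute size is a common heuristic, but the features were not standardized here, so comparing magnitudes across features with different units should be read with caution.

```python
import math

# Coefficients copied from feature_dict above (rounded to 4 decimals)
coefs = {
    'age': 0.0051, 'sex': -1.4325, 'cp': 0.7800, 'trestbps': -0.0108,
    'chol': -0.0020, 'fbs': 0.0977, 'restecg': 0.7156, 'thalach': 0.0305,
    'exang': -0.8003, 'oldpeak': -0.4453, 'slope': 0.5360, 'ca': -0.6684,
    'thal': -1.1580,
}

# exp(coef) is the multiplicative change in the odds of target = 1
# for a one-unit increase in the feature, holding the others fixed
odds_ratios = {name: math.exp(c) for name, c in coefs.items()}

# Rank by absolute coefficient size
ranked = sorted(coefs, key=lambda name: abs(coefs[name]), reverse=True)

for name in ranked:
    print(f"{name:>8}: coef={coefs[name]:+.4f}  odds ratio={odds_ratios[name]:.3f}")
```

For example, the odds ratio for `sex` is below 1: under this model, being male (sex = 1) lowers the odds of `target = 1`, which agrees with the crosstab in the EDA section.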
feature_df = pd.DataFrame(feature_dict, index=['feature_importance'])
feature_df.T.plot(
    kind='bar',
    title='Feature Importance',
    legend=False,
)
plt.xticks(rotation=30)
plt.show()

[Figure: bar chart of the logistic regression coefficient for each feature]


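6. Further experiments

If the target has not been reached (for example, the 95% accuracy set for this project), keep investigating:

If the target has been reached, think about:
How would you present the results to others?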

  • Original article: https://blog.csdn.net/SpriteNym/article/details/126710192