• Python实现线性判别分析教程


    当有一组预测变量需要被分为两个类,一般使用逻辑回归模型。举例,使用信用分和平均存款余额预测贷款是否违约。但当预测变量有多种可能时,则一般会使用线性判别分析(linear discriminant analysis, 简称 LDA).

    线性判别分析

    线性判别分析的场景举例:
    给定高校篮球运动员的场均篮板和得分,预测他们会被三个高校中的一个录取。虽然LDA和逻辑回归模型都可以进行分类。实践表明,在对多个类进行预测时,LDA比逻辑回归要稳定得多,因此当响应变量有两个以上类别时,LDA是首选的算法。与逻辑回归相比,当样本量较小时LDA的表现也更好,这让它成为无法收集大样本时的首选方法。

    构建LDA模型

    线性判别算法对数据有一些要求:

    • 响应变量必须是类别变量。线性判别是分类算法,因此响应变量应该是类别变量。

    • 预测变量应遵循正太分布。首先检查每个预测变量是否大致符合正太分布,如果不满足,需要选择转换算法使其近似满足。

    • 每个预测变量有相同的标准差。现实中很难能够满足该条件,但我们可以对数据进行标准化,让变量统一为标准差为1,均值为0.

    • 检查异常值。在用于LDA之前要检查异常值。可以简单通过箱线图或散点图查进行检测。

    一旦这些假设满足,LDA会估计下面值:

    μ k {μ_k} μk: 第 k t h {k^{th}} kth类所有训练集的均值.

    σ 2 {σ^2} σ2: 第k类样本方差的加权平均值.

    π k {π_k} πk: 属于第k类的训练观察值的比例.

    然后LDA将这些数字代入以下公式,并将每个观测值X = X分配给公式产生最大值的类:

    D k ( x ) = x ∗ ( μ k / σ 2 ) – ( μ k 2 / 2 σ 2 ) + l o g ( π k ) {D_k(x) = x * (μ_k/σ^2) – (μ_k^2/2σ^2) + log(π_k)} Dk(x)=x(μk/σ2)(μk2/2σ2)+log(πk)

    注意,LDA的名称中有线性,因为上面函数产生的值来自x的线性函数的结果。

    LDA应用场景

    LDA模型在现实中应用广泛,下面简单举例:

    市场营销
    零售公司经常使用LDA将购物者分为几类。然后利用建立LDA模型来预测特定购物者是低消费者、中等消费者还是高消费者,使用预测变量如收入、年度总消费额和家庭人数等变量。

    医学领域
    医院或医疗机构的研究人员通常利用LDA预测给定一组异常细胞是否会导致轻微、中度或严重疾病。

    产品研发
    一些公司会利用LDA模型预测消费者属于每天、每周、每月或年使用他们的产品,基于预测变量有性别、年度收入、使用类似产品的频率。

    生态领域
    研究者利用LDA模型预测是否给定珊瑚礁的健康状况:好、中等、坏、严重。预测变量包括大小、年度污染情况、年份。

    示例

    下面分步实现LDA,首先加载必要的包。

    加载工具包

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from sklearn import datasets
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
    from sklearn.model_selection import RepeatedStratifiedKFold
    from sklearn.model_selection import cross_val_score
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    加载数据

    本实例使用iris数据集,下面代码展示如何加载数据,并转为DataFrame:

    # load iris dataset
    iris = datasets.load_iris()
    
    # convert dataset to pandas DataFrame
    df = pd.DataFrame(data=np.c_[iris['data'], iris['target']],
                      columns=iris['feature_names'] + ['target'])
    df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)
    df.columns = ['s_length', 's_width', 'p_length', 'p_width', 'target', 'species']
    
    # view first six rows of DataFrame
    print(df.head())
    
    print(len(df.index))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    输出结果:

    #    s_length  s_width  p_length  p_width  target species
    # 0       5.1      3.5       1.4      0.2     0.0  setosa
    # 1       4.9      3.0       1.4      0.2     0.0  setosa
    # 2       4.7      3.2       1.3      0.2     0.0  setosa
    # 3       4.6      3.1       1.5      0.2     0.0  setosa
    # 4       5.0      3.6       1.4      0.2     0.0  setosa
    
    # 150 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    我们看到包括150条观测记录,下面构建LDA预测属于那个分类。

    预测变量为:

    Sepal length
    Sepal width
    Petal length
    Petal width

    结果分类包括:

    setosa
    versicolor
    virginica

    拟合模型

    使用LinearDiscriminantAnalysis 函数拟合模型:

    # define predictor and response variables
    x = df[['s_length', 's_width', 'p_length', 'p_width']]
    y = df['species']
    
    # Fit the LDA model
    model = LinearDiscriminantAnalysis()
    model.fit(x.values, y.values)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    使用模型进行预测

    我们已经拟合了模型,为了评估模型,使用k折分组教程验证. 使用10个分组,重复3次:

    # Define method to evaluate model
    cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
    
    # evaluate model
    scores = cross_val_score(model, x, y, scoring='accuracy', cv=cv, n_jobs=-1)
    print(np.mean(scores))
    
    # 0.9800000000000001
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    模型平均准确率为:98%。 下面使用测试数据进行预测:

    # define new observation
    new = [5, 3, 1, .4]
    
    # predict which class the new observation belongs to
    model.predict([new])
    
    # array([0]) 即第一类:setosa
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    可视化结果

    最后使用LDA图查看线性判别结果:

    # define data to plot
    x = iris.data
    y = iris.target
    model = LinearDiscriminantAnalysis()
    data_plot = model.fit(x, y).transform(x)
    target_names = iris.target_names
    
    # create LDA plot
    plt.figure()
    colors = ['red', 'green', 'blue']
    lw = 2
    for color, i, target_name in zip(colors, [0, 1, 2], target_names):
        plt.scatter(data_plot[y == i, 0], data_plot[y == i, 1], alpha=.8, color=color,
                    label=target_name)
    
    # add legend to plot
    plt.legend(loc='best', shadow=False, scatterpoints=1)
    
    # display LDA plot
    plt.show()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    在这里插入图片描述

  • 相关阅读:
    Mysql整理-索引
    java计算机毕业设计好物网站MyBatis+系统+LW文档+源码+调试部署
    Kubernetes亲和性学习笔记
    Spring构造注入的几种方式
    Spring Cloud Circuit Breaker 使用示例
    2023年【电工(技师)】试题及解析及电工(技师)模拟考试题
    项目打包优化
    Linux 环境搭建以及xshell远程连接
    php字符串处理函数的使用
    怎么语音转文字?快来看看这些方法
  • 原文地址:https://blog.csdn.net/neweastsun/article/details/126587861