• 使用LDA(线性判别公式)进行iris鸢尾花的分类


            线性判别分析((Linear Discriminant Analysis ,简称 LDA)是一种经典的线性学习方法,在二分类问题上因为最早由 [Fisher,1936] 提出,亦称 ”Fisher 判别分析“。并且LDA也是一种监督学习的降维技术,也就是说它的数据集的每个样本都有类别输出。这点与主成分和因子分析不同,因为它们是不考虑样本类别的无监督降维技术。

            LDA 的思想非常朴素:给定训练样例集,设法将样例投影到一条直线上,使得同样样例的投影尽可能接近、异样样例的投影点尽可能远离;在对新样本进行分类时,将其投影到同样的这条直线上,再根据投影点的位置来确定新样本的类别。其实可以用一句话概括:就是“投影后类内方差最小,类间方差最大”。
    鸢尾花简介

    iris数据集的中文名是安德森鸢尾花卉数据集,英文全称是Anderson’s Iris data set。iris包含150个样本,对应数据集的每行数据。每行数据包含每个样本的四个特征和样本的类别信息,所以iris数据集是一个150行5列的二维表。

    通俗地说,iris数据集是用来给花做分类的数据集,每个样本包含了花萼长度、花萼宽度、花瓣长度、花瓣宽度四个特征(前4列),我们需要建立一个分类器,分类器可以通过样本的四个特征来判断样本属于山鸢尾、变色鸢尾还是维吉尼亚鸢尾(这三个名词都是花的品种)。

    iris的每个样本都包含了品种信息,即目标属性(第5列,也叫target或label)。

    代码

    1. #首先导入相关库
    2. import sklearn
    3. from sklearn.datasets import load_iris
    4. from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
    5. from sklearn.model_selection import train_test_split
    6. import matplotlib.pyplot as plt
    1. #再进行数据的划分
    2. data = load_iris(return_X_y=True)
    3. x,y = data
    4. #print(x)
    5. #print(y)
    6. #分割训练集和测试集
    7. train_x,test_x,train_y,test_y = train_test_split(x,y,test_size=0.3)
    8. print(train_x.shape)
    9. print(test_x.shape)
    1. #进行训练
    2. LDA = LinearDiscriminantAnalysis()
    3. LDA.fit(train_x,train_y)
    4. y_predict = LDA.predict(test_x)
    5. print(test_y)
    6. print(y_predict)

    相关输出如下

    [2 1 2 1 0 2 2 0 2 0 1 2 1 0 1 0 0 0 0 2 2 1 2 1 0 1 1 2 2 0 2 1 2 0 2 1 2
     1 0 2 0 0 1 0 2]
    [2 1 2 1 0 2 2 0 2 0 1 2 1 0 1 0 0 0 0 2 2 1 2 1 0 1 1 2 2 0 2 1 2 0 2 1 2
     1 0 2 0 0 1 0 2]
    1. #计算预测正确率
    2. j = 0
    3. for i in range(len(test_y)):
    4. if test_y[i] == y_predict[i]:
    5. j = j + 1
    6. print(j)
    7. print(j/len(y_predict))

    画图部分

    1. #由于是按照萼片长度宽度计算,所以将萼片长宽与相应的类别组合成新的列表
    2. total_sepal = []
    3. for i in range(x.shape[0]):
    4. sepal = []
    5. sepal.append(x[i][0])
    6. sepal.append(x[i][1])
    7. sepal.append(y[i])
    8. total_sepal.append(sepal)
    9. print(total_sepal)
    1. #画图
    2. for i in range(x.shape[0]):
    3. if(total_sepal[i][2] == 0):
    4. plt.scatter(total_sepal[i][0], total_sepal[i][1], color='blue')
    5. if(total_sepal[i][2] == 1):
    6. plt.scatter(total_sepal[i][0], total_sepal[i][1], color='red')
    7. if(total_sepal[i][2] == 2):
    8. plt.scatter(total_sepal[i][0], total_sepal[i][1], color='green')
    9. plt.show()

  • 相关阅读:
    LabVIEW车体静强度试验台测控系统
    【软考 系统架构设计师】案例分析⑨ 数据库优化
    mysql兼容微信表情
    前端JavaScript入门到精通,javascript核心进阶ES6语法、API、js高级等基础知识和实战 —— JS基础(五)
    谷歌浏览查询http被自动转化成https导致页面读取失败问题处理
    逆置链表(原地逆置链表)
    分组后比较组内数据
    【Liunx系统编程】命令模式3
    标志寄存器
    Rust——包管理
  • 原文地址:https://blog.csdn.net/qq_36035111/article/details/133149500