• 机器学习实验一:使用 Logistic 回归来预测患有疝病的马的存活问题


    代码: 

    1. import pandas as pd
    2. import numpy as np
    3. from sklearn.preprocessing import StandardScaler
    4. from sklearn.linear_model import LogisticRegression
    5. from sklearn.metrics import classification_report
    6. import matplotlib.pyplot as plt
    7. def train():
    8. # 1)读取数据:
    9. df1=pd.read_csv('horseColicTraining.txt',delimiter='\t',header=None)
    10. df2=pd.read_csv('horseColicTest.txt',delimiter='\t',header=None)
    11. last_column = df1.iloc[:, -1] # 获取最后一列数据
    12. x_train1 = df1.iloc[:, :-1] # 第一个DataFrame包含除最后一列以外的所有列
    13. y_train1 = pd.DataFrame(last_column) # 第二个DataFrame只包含最后一列
    14. last_column1 = df2.iloc[:, -1] # 获取最后一列数据
    15. x_test1 = df2.iloc[:, :-1] # 第一个DataFrame包含除最后一列以外的所有列
    16. y_test1 = pd.DataFrame(last_column1) # 第二个DataFrame只包含最后一列
    17. # 2)缺失值处理:
    18. #3)划分数据集:
    19. # 筛选特征值和目标值
    20. # 4)特征工程标准化
    21. transfer=StandardScaler()
    22. x_train=transfer.fit_transform(x_train1)
    23. # print(x_train)
    24. x_test=transfer.transform(x_test1)
    25. # transfer1=StandardScaler()
    26. # y_train=transfer.fit_transform(y_train1)
    27. # y_test=transfer.transform(y_test1)
    28. # 二维数组
    29. two_dimensional_array = np.array(y_train1)
    30. # 使用flatten()函数将二维数组转换为一维数组
    31. y_train = two_dimensional_array.flatten()
    32. # print(y_train)
    33. # 5)逻辑回归的预估器:
    34. estimator=LogisticRegression(C=0.04,max_iter=10000)
    35. estimator.fit(x_train,y_train)
    36. # 回归系数和偏置
    37. print('回归系数为:\n',estimator.coef_)
    38. print('偏置为:',estimator.intercept_)
    39. # 6)分类模型的评估
    40. y_predict=estimator.predict(x_test)
    41. print('测试集的预测值为:\n',y_predict)
    42. error=estimator.score(x_test,y_test1)
    43. print('模型预测准确率为:',error)
    44. # 查看精确率和召回率和F1—score
    45. report=classification_report(y_test1,y_predict,labels=[1,0],target_names=['死亡','没死'])
    46. print(report)#precision:精确率 recall:召回率 f1-score support:数量
    47. return y_predict,y_test1
    48. y1,y2=train()
    49. # print(y)
    50. # plt.plot(np.linspace(0,67,67),y)
    51. fig=plt.figure()
    52. plt.scatter(np.linspace(0,67,67),y1,alpha=0.5)
    53. plt.scatter(np.linspace(0,67,67),y2,alpha=0.5)
    54. plt.show()

     结果可视化:(随便写的一个)

  • 相关阅读:
    DAY04-网页布局实战&常用HTML标签&完整盒模型
    CRM系统对科技企业有哪些帮助
    Java 17 VS Java 8: 新旧对决,这些Java 17新特性你不容错过
    第二章:Qt下载与安装 之 2.2 Qt安装
    ORB-SLAM2 ---- ORBmatcher::SearchForTriangulation函数
    【计算机网络】什么是http?
    RocketMq4 消息发送示例及源码浅阅
    9.1 运用API创建多线程
    解锁Mysql中的JSON数据类型,怎一个爽字了得
    Java 中项目路径映射物理机磁盘路径配置。
  • 原文地址:https://blog.csdn.net/qq_46103282/article/details/133207768