• 深度学习 神经网络(5)逻辑回归二分类-Pytorch实现乳腺癌预测


    一、前言

    本文主要介绍了pytorch构造神经网络来实现乳腺癌的预测。

    乳腺癌预测是神经网络应用于逻辑回归二分类问题的一个典型案例。

    跟线性回归的区别在于使用sigmoid激活函数输出。关于该函数可以参考我的另一篇文章《sigmoid函数及其图像绘制》。

    我们使用的是sklearn的乳腺癌数据集。该数据集有30个特征,输出0或1,表示是否患有乳腺癌。

    二、代码实现

    接下来我们使用Pytorch的Sequential方法实现神经网络。

    2.1 引入依赖库

    from sklearn import datasets
    from sklearn.preprocessing import MinMaxScaler
    from sklearn.model_selection import train_test_split
    import torch
    import torch.nn as nn
    import pandas as pd
    import warnings
    #忽略警告
    warnings.filterwarnings('ignore')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    2.2 加载并查看数据集

    dataset= datasets.load_breast_cancer()
    X = dataset.data
    Y = dataset.target
    
    #加载数据集
    data_df = pd.DataFrame(dataset.data, columns=dataset.feature_names)
    data_df['result'] = dataset.target
    data_df.head(10)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    在这里插入图片描述

    2.3 数据处理

    #数据归一化
    X=MinMaxScaler().fit_transform(X)
    data_df = pd.DataFrame(X, columns=dataset.feature_names)
    
    data_df.head(10)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    在这里插入图片描述

    2.4 数据分割

    # 将数据分割为训练和验证数据,都有特征和预测目标值
    # 分割基于随机数生成器。为random_state参数提供一个数值可以保证每次得到相同的分割
    X_train, X_test, y_train, y_test = train_test_split(X, Y, random_state = 0)
    print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
    
    • 1
    • 2
    • 3
    • 4
    (426, 30) (143, 30) (426,) (143,)
    
    • 1

    2.5 迭代训练

    x=torch.tensor(X_train,dtype=torch.float32)
    y=torch.tensor(y_train,dtype=torch.float32)
    
    # 把标签转为2维
    y = y.view(y.shape[0],1)
    
    #迭代次数
    epochs=1000
    
    #学习率
    learning_rate=0.5
    
    plt_epoch=[]
    plt_loss=[]
    
    model = nn.Sequential(
        nn.Linear(x.size()[1], 10),
        nn.Sigmoid(),
        nn.Linear(10, 1),
        nn.Sigmoid()
    )
    
    #损失函数
    cost=nn.BCELoss()
    #迭代优化器
    optmizer=torch.optim.SGD(model.parameters(),lr=learning_rate)
    
    for epoch in range(epochs):
        #预测结果
        predictions=model(x) #调用__call__函数
    
        #计算损失值
        loss=cost(predictions,y)
    
        #在反向传播前先把梯度清零
        optmizer.zero_grad()
    
        #反向传播,计算各参数对于损失loss的梯度
        loss.backward()
    
        #根据刚刚反向传播得到的梯度更新模型参数
        optmizer.step()
    
        plt_epoch.append(epoch)
        plt_loss.append(loss.item())
    
        #打印损失值
        if epoch%100==0:
            print('epoch:',epoch,'loss:',loss.item())
    
    
    #绘制迭代次数与损失函数的关系
    import matplotlib.pyplot as plt
    plt.plot(plt_epoch,plt_loss)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    epoch: 0 loss: 0.8293326497077942
    epoch: 100 loss: 0.5075534582138062
    epoch: 200 loss: 0.24081537127494812
    epoch: 300 loss: 0.16377314925193787
    epoch: 400 loss: 0.13307765126228333
    epoch: 500 loss: 0.11592639237642288
    epoch: 600 loss: 0.10465363413095474
    epoch: 700 loss: 0.09659579396247864
    epoch: 800 loss: 0.09052523970603943
    epoch: 900 loss: 0.0857650563120842
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    在这里插入图片描述

    2.6 数据验证

    #测试数据
    x_t=torch.tensor(X_test,dtype=torch.float32)
    y_t=torch.tensor(y_test,dtype=torch.float32)
    # 把标签转为2维
    y_t = y_t.view(y_t.shape[0],1)
    
    #预测结果
    predictions=model(x_t)
    #计算损失值
    loss=cost(predictions,y_t)
    
    print('loss:',loss.detach().item())
    
    predictions=torch.where(predictions>0.5,1,0)
    print(f"预测准确率: {(torch.sum(predictions == y_t)/y_t.size()[0]) * 100}%")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    loss: 0.07796578109264374
    预测准确率: 97.20279693603516%
    
    • 1
    • 2
  • 相关阅读:
    数据化运营19 传播(上):如何打造千万级的私域运营体系?
    mysql数据库基础:视图、变量
    CMake系列(九) CMake 头文件接口库编译及使用
    spring boot 之 整合 knife4j 在线接口文档
    汽车信息安全导图
    多线程必知必会的知识点
    【Flink实战】玩转Flink里面核心的Sink Operator实战
    大学生游戏静态HTML网页设计 (HTML+CSS+JS仿英雄联盟网站15页)
    Springboot项目:连接mysql数据库,使用aop进行日志捕获
    选Redis做 mq 的人,是水平欠缺么?
  • 原文地址:https://blog.csdn.net/Leytton/article/details/127581795