• 神经网络之万能定理python-pytorch实现,可以拟合任意曲线


    神经网络之万能定理python-pytorch实现,可以拟合任意曲线

    博主,这几天一直在做这个曲线拟合的实验,讲道理,网上可能也有很多这方面的资料,但是博主其实试了很多,效果只能对一般的曲线还行,稍微复杂一点的,效果都不太好,后来博主经过将近一天奋战终于得到了这个最好的结果:

    代码:

    from turtle import shape
    import torch
    from torch import nn
    import pandas as pd
    import numpy as np
    from scipy.optimize import curve_fit
    import matplotlib.pyplot as plt
    from  utils import  parameters
    from scipy.optimize import leastsq
    from turtle import title
    import numpy as np
    import matplotlib.pyplot as plt
    import torch as t
    from torch.autograd import Variable as var
    
    
    class BP(t.nn.Module):
        def __init__(self):
            super(BP,self).__init__()
            self.linear1 = t.nn.Linear(1,100)
            self.s = t.nn.Sigmoid()
            self.linear2 = t.nn.Linear(100,10)
            self.relu = t.nn.Tanh()
            self.linear3 = t.nn.Linear(10,1)
            self.Dropout = t.nn.Dropout(p = 0.1)
            self.criterion = t.nn.MSELoss()
            self.opt = t.optim.SGD(self.parameters(),lr=0.01)
        def forward(self, input):
            y = self.linear1(input)
            y = self.relu(y)
         #  y=self.Dropout(y)
            y = self.linear2(y)
            y = self.relu(y)
           # y=self.Dropout(y)
            y = self.linear3(y)
            y = self.relu(y)
            return y
    
    
    class BackPropagationEx:
        def __init__(self):
            self.popt=[]
        #def  fun(self,t,a,Smax,S0,t0):
        #           return Smax - (Smax-S0) * np.exp(-a * (t-t0));
        def curve_fitm(self,x,y,epoch):
          
            xs =x.reshape(-1,1)
    
    
            xs=(xs-xs.min())/(xs.max()-xs.min())
           
    
         #   print(xs)
            ys = y
            ys=(ys-ys.min())/(ys.max()-ys.min())
            xs = var(t.Tensor(xs))
    
            ys = var(t.Tensor(ys))
         #   bp = BP(traindata=traindata,labeldata=labeldata,node=[1,6,1],epoch=1000,lr=0.01)
          #  predict=updata(10,traindata,labeldata)
            model=BP()
            for e in range(epoch):
             #   print(e)
                index=0
                ls=0
                for x in xs:
                    y_pre = model(x)
                   #   print(y_pre)
                    loss = model.criterion(y_pre,ys[index])
                    index=index+1
             #       print("loss",loss)
                   
                    ls=ls+loss
                    # Zero gradients
                    model.opt.zero_grad()
                    # perform backward pass
                    loss.backward()
                    # update weights
                    model.opt.step()
                if(e%2==0 ):
                        print(e,ls)
            ys_pre = model(xs)
            loss = model.criterion(y_pre,ys)
            print(loss)
    
    
    
    
            plt.title("curve")
            plt.plot(xs.data.numpy(),ys.data.numpy(),label="ys")
            plt.plot(xs.data.numpy(),ys_pre.data.numpy(),label="ys_pre")
            plt.legend()
            plt.show()
    
        def predict(self,x):
            return self.fun(x,*self.popt)
        def plot(self,x,y,predict):
             plt.plot(x,y,'bo')
            #绘制拟合后的点
             plt.plot(x,predict,'r-')#拟合的参数通过*popt传入
             plt.title("BP神经网络")
             plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102

    来看一下结果:
    在这里插入图片描述

    你们可能觉得这个拟合好像也一般啊,其实不是,我这个问题非常难,基本上网上的代码都是拟合效果很差的,数据的话,感兴趣的,可以私聊我,我可以发给你们。
    这个实现想做到博主这个效果的,很难,因为博主做了大量实现,发现,其实严格意义上的万能定理的实现其实是需要很多的考虑的。
    另外随着训练轮数和神经元的增加,实际上我们的效果可以真正实现万能定理。

  • 相关阅读:
    网页轮播图
    计算机视觉 | 交通信号灯状态的检测和识别
    JavaWeb基础10——VUE&Element&整合Javaweb的商品管理系统
    几个Caller-特性的妙用
    LED驱动IC:HC2160,内置60V功率MOS升压型LED恒流驱动IC。供应LED灯杯单节电池以上供电的LED灯串平板显示LED背光大功率LED照明
    黑豹程序员-架构师学习路线图-百科:Lombok消除冗长的java代码
    transformer一统天下?depth-wise conv有话要说
    联想小新如果使用蓝牙鼠标在关闭了触摸板的情况下不小心关掉了蓝牙该如何处理?
    如何根据不同需求给Word文档设置保护?
    1.ROS编程学习:helloworld的c++与python实现
  • 原文地址:https://blog.csdn.net/weixin_43327597/article/details/136408370