• PyTorch教程- 回归问题


    参考:

          主要参考 课时5 简单回归问题-2_哔哩哔哩_bilibili

    系统的回顾一下pytorch


    目录

        1: 简单回归问题

        2: 回归问题实战


    一    简单回归问题(Linear Regression)

              根据预测值(标签值)是连续值还是离散值,监督学习可分为回归问题和分类问题;这里讨论标签为连续值的回归问题

              

             线性回归问题

              \hat{y}=w^Tx+b

             损失函数

              loss = \frac{1}{2}\sum_{i}(w^Tx_i+b-y_i)^2

             参数学习:

                  梯度下降原理(泰勒公式展开)

                  设z_i=\hat{y_i}-y_i

                   loss=\frac{1}{2} \sum_i z_i^2

                  \frac{\partial loss}{\partial w}=\sum_i z_ix_i=\sum_{i}(\hat{y_i}-y_i)x_i


          二 回归问题实战

            数据集

                     (x_i,y_i) y_i \in R

             模型

                         \hat{y}=w^Tx+b

             参数学习

                          设 z_i=\hat{y_i}-y_i

                          loss= \frac{1}{2N}\sum_{i=1}^{N} z_i^2

              梯度

                         b_{grad}=\frac{1}{N}\sum z_i

                         w_{grad}=\frac{1}{N}\sum_{i}z_ix_i

              参数更新           

                         w=w-\alpha* w_{grad}

                         b= b-\alpha *b_{grad}


    三 例

      3.1 训练部分

    1. # -*- coding: utf-8 -*-
    2. """
    3. Created on Thu Nov 10 21:33:37 2022
    4. @author: cxf
    5. """
    6. import torch
    7. import numpy as np
    8. from torch.utils.data import Dataset, DataLoader
    9. from draw import draw_loss
    10. #需要继承data.Dataset
    11. class MyDataset(Dataset):
    12. def __init__(self, data, target):
    13. self.x = data
    14. self.y = target
    15. self.len = self.x.shape[0] #样本个数
    16. def __getitem__(self, index):
    17. x = self.x[index]
    18. y = self.y[index]
    19. return x,y
    20. def __len__(self):
    21. return self.len
    22. #linear regression
    23. class LR:
    24. '''
    25. 预测值
    26. args
    27. w: 权重系数
    28. b: 偏置系数
    29. '''
    30. def predict(self,w,b,x):
    31. predY= torch.mm(w.T,x)+b
    32. return predY
    33. '''
    34. 梯度更新
    35. args
    36. w_cur: 权重系数
    37. b_cur 偏置
    38. trainX: 训练数据集
    39. trainY: 标签集
    40. '''
    41. def step_gradient(self,w_cur,b_cur, trainX,trainY):
    42. w_gradient = 0
    43. b_gradient = 0
    44. m = trainX.shape[0]
    45. N = float(m)
    46. for i in range(0,m):
    47. x = trainX[i].view(self.n,1)
    48. y = trainY[i]
    49. predY = self.predict(w_cur,b_cur,x)
    50. delta = predY - y
    51. b_gradient +=(2/N)*delta
    52. w_gradient +=(2/N)*delta*x
    53. new_b = b_cur- self.learnRate*b_gradient
    54. new_w = w_cur- self.learnRate*w_gradient
    55. return new_w,new_b
    56. '''
    57. 梯度下降训练
    58. args
    59. dataX: 数据集
    60. dataY: 标签集
    61. '''
    62. def train(self,dataX,dataY):
    63. y_train_loss =[]
    64. b_cur = torch.zeros([1,1],dtype=torch.float)
    65. w_cur = torch.rand((self.n,1),dtype=torch.float)
    66. trainData = MyDataset(dataX, dataY)
    67. train_loader = DataLoader(dataset = trainData, batch_size =self.batch, shuffle = True,drop_last =True)
    68. for epoch in range(self.maxIter):
    69. for step, (batch_x, batch_y) in enumerate(train_loader):
    70. w,b = self.step_gradient(w_cur, b_cur, batch_x,batch_y)
    71. w_cur = w
    72. b_cur = b
    73. loss = self.compute_error(w, b, dataX, dataY)
    74. #print("\n epoch: ",epoch,"\n loss ",loss)
    75. y_train_loss.append(loss)
    76. return y_train_loss
    77. def compute_error(self,w,b, dataX,dataY):
    78. totalError = 0.0
    79. m = len(dataX)
    80. for i in range(0,m):
    81. x = dataX[i].view(self.n,1)
    82. y = dataY[i]
    83. predY = self.predict(w, b, x)
    84. z = predY-y
    85. loss = np.power(z,2)
    86. totalError+=loss
    87. totalError = totalError.numpy()[0,0]
    88. return totalError
    89. '''
    90. 加载数据集
    91. '''
    92. def loadData(self):
    93. data = np.genfromtxt("data.csv",delimiter=",")
    94. trainData = data[:,0:-1]
    95. trainLabel = data[:,-1]
    96. x = torch.tensor(trainData, dtype=torch.float)
    97. y = torch.tensor(trainLabel, dtype = torch.float)
    98. self.m ,self.n=x.shape[0],x.shape[1]
    99. print("\n m ",self.m,"\t n",self.n)
    100. return x,y
    101. def __init__(self):
    102. self.w = 0 #权重系数
    103. self.b = 0 #偏置
    104. self.m = 0 #样本个数
    105. self.n = 0 #样本维度
    106. self.batch = 20 #训练用的样本数
    107. self.maxIter = 1000 #最大迭代次数
    108. self.learnRate = 0.01 #学习率
    109. if __name__ == "__main__":
    110. lr = LR()
    111. x,y = lr.loadData()
    112. loss = lr.train(x, y)
    113. draw_loss(loss)

    3.2 绘图部分

      

    1. # -*- coding: utf-8 -*-
    2. """
    3. Created on Mon Nov 14 20:14:28 2022
    4. @author: cxf
    5. """
    6. import numpy as np
    7. import matplotlib.pyplot as plt
    8. def draw_loss(y_train_loss):
    9. plt.figure()
    10. x_train_loss = range(len(y_train_loss))
    11. # 去除顶部和右边框框
    12. ax = plt.axes()
    13. ax.spines['top'].set_visible(False)
    14. ax.spines['right'].set_visible(False)
    15. #标签
    16. plt.xlabel('iters')
    17. plt.ylabel('accuracy')
    18. plt.plot(x_train_loss, y_train_loss, linewidth=1, linestyle="solid", label="train loss")
    19. plt.legend()
    20. plt.title('train loss')
    21. plt.show()

    3.3 数据部分

    1. # -*- coding: utf-8 -*-
    2. """
    3. Created on Fri Nov 11 22:17:07 2022
    4. @author: cxf
    5. """
    6. import numpy as np
    7. import csv
    8. def makeData():
    9. wT = np.array([[1.0,2.0,2.0]])
    10. b = 0.5
    11. Data = np.random.random((200,3))
    12. m,n = np.shape(Data)
    13. trainData =[]
    14. for i in range(m):
    15. x = Data[i].T
    16. y = np.matmul(wT,x)+b
    17. item =list(x)
    18. item.append(y[0])
    19. trainData.append(item)
    20. return trainData
    21. def save(data):
    22. csvFile = open("data.csv",'w',newline='')
    23. wr = csv.writer(csvFile)
    24. m = len(data)
    25. for i in range(m):
    26. wr.writerow(data[i])
    27. csvFile.close()
    28. makeData()
    29. if __name__ =="__main__":
    30. data = makeData()
    31. save(data)

        

  • 相关阅读:
    力扣第1005题 K 次取反后最大化的数组和 c++ 贪心 双思维
    2024.2.15力扣每日一题——二叉树的层序遍历2
    Java BufferedReader类简介说明
    Unreal Property System (Reflection) 虚幻属性系统(反射)
    SQL练习 2022/7/2
    【java】网络编程
    牛顿-拉夫森算法:用Python实现
    猿创征文|当我在追光 我与光同航--我与Java的技术成长之路
    RabbitMQ 常见问题
    jvm介绍
  • 原文地址:https://blog.csdn.net/chengxf2/article/details/127813459