• Applying different optimizers


     Just a quick hands-on tryout. For the principles behind each optimizer, see the two references below (a minimal sketch of their update rules follows the links):

    Optimizer principles in deep learning (SGD, SGD+Momentum, Adagrad, RMSProp, Adam) — bilibili video

    Bookmark this: the most complete roundup of machine-learning optimizers — Zhihu (zhihu.com)
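
    For orientation before reading the code, here is a minimal single-parameter sketch of the update rules those references walk through. It is an editorial addition, not from the original post; the hyperparameter defaults (lr, beta, rho, eps) are the usual textbook values, not anything taken from the code below.

    import numpy as np

    def sgd(w, g, lr=0.01):
        return w - lr * g

    def sgd_momentum(w, g, v, lr=0.01, beta=0.9):
        v = beta * v + g                    # velocity: running sum of gradients
        return w - lr * v, v

    def adagrad(w, g, G, lr=0.01, eps=1e-8):
        G = G + g ** 2                      # accumulated squared gradients
        return w - lr * g / (np.sqrt(G) + eps), G

    def rmsprop(w, g, s, lr=0.01, rho=0.9, eps=1e-8):
        s = rho * s + (1 - rho) * g ** 2    # exponential moving average of g^2
        return w - lr * g / (np.sqrt(s) + eps), s

    def adam(w, g, m, v, t, lr=0.01, b1=0.9, b2=0.999, eps=1e-8):
        m = b1 * m + (1 - b1) * g           # first-moment estimate
        v = b2 * v + (1 - b2) * g ** 2      # second-moment estimate
        m_hat = m / (1 - b1 ** t)           # bias correction (t is 1-based)
        v_hat = v / (1 - b2 ** t)
        return w - lr * m_hat / (np.sqrt(v_hat) + eps), m, v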

    import matplotlib.pyplot as plt
    import torch

    # prepare dataset
    # x and y are 3x1 matrices: 3 samples, each with a single feature
    x_data = torch.tensor([[1.0], [2.0], [3.0]])
    y_data = torch.tensor([[2.0], [4.0], [6.0]])

    class LinearModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(1, 1)

        def forward(self, x):
            return self.linear(x)

    criterion = torch.nn.MSELoss(reduction='sum')

    # the eight optimizers under comparison, keyed by subplot title
    optimizer_classes = {
        'SGD': torch.optim.SGD,
        'Adagrad': torch.optim.Adagrad,
        'Adam': torch.optim.Adam,
        'Adamax': torch.optim.Adamax,
        'ASGD': torch.optim.ASGD,
        'LBFGS': torch.optim.LBFGS,
        'RMSprop': torch.optim.RMSprop,
        'Rprop': torch.optim.Rprop,
    }

    epoch_list = list(range(100))
    losses = {name: [] for name in optimizer_classes}

    for name, opt_cls in optimizer_classes.items():
        # a fresh model per optimizer, so every curve starts from
        # untrained weights instead of continuing where the previous
        # optimizer left off
        model = LinearModel()
        optimizer = opt_cls(model.parameters(), lr=0.01)

        def closure():
            optimizer.zero_grad()
            loss = criterion(model(x_data), y_data)
            loss.backward()
            return loss

        for epoch in range(100):
            # step(closure) works for all eight optimizers; LBFGS in
            # particular requires the closure (see the note below)
            loss = optimizer.step(closure)
            losses[name].append(loss.item())

        # predict y for x = 4.0 with the freshly trained model
        x_test = torch.tensor([[4.0]])
        print(name, 'y_pred =', model(x_test).item())

    # plot the eight loss curves in a 2x4 grid
    for i, name in enumerate(optimizer_classes):
        plt.subplot(2, 4, i + 1)
        plt.title(name)
        plt.plot(epoch_list, losses[name])
        plt.ylabel('cost')
        plt.xlabel('epoch')
    plt.tight_layout()
    plt.show()
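
    A note on the closure: torch.optim.LBFGS re-evaluates the objective several times within a single step, so its step() must be passed a closure that zeroes the gradients, recomputes the loss, and calls backward(). The other optimizers accept the closure as an optional argument and simply call it once, which is why the loop above can treat all eight uniformly. The more common tutorial pattern of calling zero_grad(), backward(), and step() explicitly works just as well for everything except LBFGS.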

    Result: a 2x4 grid of loss-vs-epoch curves, one subplot per optimizer (SGD, Adagrad, Adam, Adamax, ASGD, LBFGS, RMSprop, Rprop).

  • Source: https://blog.csdn.net/qq_46035581/article/details/134388673