• Pytorch实现线性回归


    模型y=x*w+b,使用Pytorch实现梯度下降算法,构建线性回归模型

    1. import torch
    2. import sys
    3. #3行1列的二位tensor
    4. x_data=torch.Tensor([[1.0],[2.0],[3.0]])
    5. y_data=torch.Tensor([[2.0],[4.0],[6.0]])
    6. class LinearModel(torch.nn.Module):
    7. def __init__(self, *args, **kwargs) -> None:
    8. # 调用父类的__init__()方法,这是PyTorch的约定,确保子类正确地初始化父类的部分
    9. super(LinearModel,self).__init__()
    10. # 创建一个线性层,输入通道数为1,输出通道数为1。这通常被称为一个线性模型或者全连接层
    11. self.linear=torch.nn.Linear(1,1)
    12. def forward(self,x):
    13. #将输入数据x传递给线性层,并将结果存储在y_pred中。线性层将每个输入值映射到一个输出值。
    14. y_pred=self.linear(x)
    15. return y_pred
    16. model=LinearModel()
    17. # 创建一个均方误差损失函数对象,该对象用于衡量模型预测值与真实值之间的差异
    18. criterion=torch.nn.MSELoss(size_average=False)
    19. # 创建一个随机梯度下降优化器对象,该对象用于更新模型的参数。这里的学习率被设置为0.01
    20. optimizer=torch.optim.SGD(model.parameters(),lr=0.01)
    21. for epoch in range(1000):
    22. # 调用模型的前向传播方法,计算出模型对训练数据的预测值
    23. y_pred=model(x_data)
    24. # 使用损失函数计算出预测值与真实值之间的差异(损失)
    25. loss=criterion(y_pred,y_data)
    26. print(epoch,loss.item())
    27. optimizer.zero_grad()
    28. # 使用反向传播算法计算出每个参数的梯度。这是损失对每个参数的偏导数
    29. loss.backward()
    30. # 根据计算出的梯度对模型参数进行更新(也称为权重更新或参数更新)
    31. optimizer.step()
    32. print("w=",model.linear.weight.item())
    33. print("b=",model.linear.bias.item())
    34. x_test=torch.Tensor([[4.0]])
    35. y_test=model(x_test)
    36. print("y_pred=",y_test.data)
    1. 0 59.847930908203125
    2. 1 26.652393341064453
    3. 2 11.874536514282227
    4. 3 5.295718193054199
    5. 4 2.366877317428589
    6. 5 1.0629061460494995
    7. ...
    8. 998 9.690893421065994e-09
    9. 999 9.508596576779382e-09
    10. w= 1.9999350309371948
    11. b= 0.0001476586185162887
    12. y_pred= tensor([[7.9999]])

  • 相关阅读:
    数据思维总结:
    安装docker并在内安装mysql
    【每日一题Day360】LC1465切割后面积最大的蛋糕 | 贪心
    vue结合echarts时,浏览器报错Initialize failed: invalid dom
    k8s pod详细讲解
    Self-Attention和Multi-Head Attention的详细代码内容(没有原理)
    excel数据透视表
    87、Redis 的 value 所支持的数据类型(String、List、Set、Zset、Hash)---->List相关命令
    xbox game bar无法打开/安装怎么办?
    【机器学习】Metrics: 衡量算法性能的关键指标
  • 原文地址:https://blog.csdn.net/m0_46306264/article/details/134233071