• 使用 PyTorch 实现一个线性回归训练函数


    使用 sklearn.datasets 中的 make_regression 创建用于线性回归的数据集

    def create_dataset():
        """Build a one-feature synthetic regression problem.

        Returns:
            A tuple ``(x, y, coef)`` where ``x`` and ``y`` are torch tensors
            built from the arrays returned by ``make_regression`` (100 samples,
            noise 10, bias 14.5, ``random_state=0`` so every run is identical),
            and ``coef`` is the true coefficient used to generate ``y``.
        """
        features, targets, true_coef = make_regression(
            n_samples=100,
            n_features=1,
            noise=10,
            bias=14.5,
            coef=True,
            random_state=0,
        )
        return torch.tensor(features), torch.tensor(targets), true_coef

    加载数据集，并将训练集按 batch_size 拆分为多个小批量（batch）

    def load_dataset(x, y, batch_size):
        """Yield consecutive, non-overlapping mini-batches of ``(x, y)``.

        Args:
            x: features, sliced along the first dimension.
            y: targets, sliced along the first dimension; its length
                determines how many batches are produced.
            batch_size: number of samples per batch; must be >= 1.

        Yields:
            ``(train_x, train_y)`` slices of exactly ``batch_size`` samples.
            Trailing samples that do not fill a whole batch are dropped, so
            every batch has the size the gradient step divides by.

        Raises:
            ValueError: if ``batch_size`` is less than 1 (raised on the first
                ``next()`` call, since this is a generator).
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        batch_num = len(y) // batch_size
        for idx in range(batch_num):
            # Stride by batch_size. Striding by batch_num only happened to work
            # when the number of batches equalled the batch size.
            start = idx * batch_size
            end = start + batch_size
            yield x[start:end], y[start:end]

    定义初始权重、偏置以及模型计算函数

    # Trainable parameters (float64 scalars tracked by autograd):
    # slope starts at 0.1, intercept starts at 0.
    w = torch.tensor(0.1, dtype=torch.float64, requires_grad=True)
    b = torch.tensor(0.0, dtype=torch.float64, requires_grad=True)


    def linear_regression(x):
        """Return the affine prediction ``x * w + b`` using the global parameters."""
        scaled = x * w
        return scaled + b

    损失函数使用平方误差（squared error）

    def linear_loss(y_pred, y_true):
        """Return the element-wise squared error; reduction is left to the caller."""
        residual = y_pred - y_true
        return residual ** 2

    使用小批量梯度下降法优化参数

    def sgd(linear_rate, batch_size):
        """Apply one mini-batch gradient-descent step to the globals ``w`` and ``b``.

        Args:
            linear_rate: the learning rate (step size). The name is kept for
                compatibility with existing callers.
            batch_size: number of samples the loss was summed over; dividing
                by it turns the summed gradient into a per-sample average.
        """
        # Update in place under no_grad instead of reassigning ``.data``.
        # ``.data`` bypasses autograd's version tracking and is discouraged.
        # ``sub_`` (not ``w -= ...``) avoids rebinding ``w`` as a local name.
        with torch.no_grad():
            w.sub_(linear_rate * w.grad / batch_size)
            b.sub_(linear_rate * b.grad / batch_size)

    训练代码

    def train():
        """Fit the global ``w`` and ``b`` to the synthetic dataset and plot results.

        Runs mini-batch gradient descent for a fixed number of epochs and
        prints the mean per-sample loss of each epoch. It then shows two
        figures: the learned line against the true generating line, and the
        loss curve.
        """
        # Load data
        x, y, coef = create_dataset()
        data_len = len(y)
        # Hyper-parameters
        batch_size = 10
        epochs = 100
        learning_rate = 0.01
        # Mean per-sample loss of every epoch
        epochs_loss = []
        for eid in range(epochs):
            total_loss = 0.0
            for train_x, train_y in load_dataset(x, y, batch_size):
                y_pred = linear_regression(train_x)
                # Targets are reshaped to a column to match y_pred's presumed
                # (batch, 1) shape; otherwise the subtraction would broadcast.
                loss_num = linear_loss(y_pred, train_y.reshape(-1, 1)).sum()
                # Clear gradients accumulated by the previous step.
                if w.grad is not None:
                    w.grad.zero_()
                if b.grad is not None:
                    b.grad.zero_()
                loss_num.backward()
                sgd(learning_rate, batch_size)
                total_loss += loss_num.item()
            # Summed loss divided by sample count = mean per-sample loss.
            b_loss = total_loss / data_len
            epochs_loss.append(b_loss)
            print("epoch={},b_loss={}".format(eid, b_loss))
        _plot_fit(x, y, coef)
        _plot_loss(epochs_loss)


    def _plot_fit(x, y, coef, true_bias=14.5):
        """Scatter the data and overlay the learned line and the true line.

        ``true_bias`` must match the ``bias`` passed to ``make_regression``
        in ``create_dataset``.
        """
        print(w, b)
        print(coef, true_bias)
        plt.scatter(x, y)
        test_x = torch.linspace(x.min().item(), x.max().item(), 1000)
        # Vectorised prediction; no_grad keeps plotted values out of autograd.
        with torch.no_grad():
            y1 = test_x * w + b
        # coef may be an ndarray; assumes a single feature (n_features=1).
        true_slope = torch.as_tensor(coef).flatten()[0].item()
        y2 = test_x * true_slope + true_bias
        plt.plot(test_x, y1, label='train')
        plt.plot(test_x, y2, label='true')
        plt.legend()  # the labels above are only drawn once a legend exists
        plt.grid()
        plt.show()


    def _plot_loss(epochs_loss):
        """Plot the mean loss recorded for each epoch."""
        plt.plot(range(len(epochs_loss)), epochs_loss)
        plt.show()

    从图中看，拟合效果还不错

    损失值在第 5 次迭代后基本就很小了

  • 相关阅读:
    Open3D 点云变换(平移、旋转及尺度)
    人大金仓分析型数据库系统扩容(六)
    记录自签tomcat所用TLS1.2链接所需SSL证书
    GBase 8a MPP集群管理之虚拟集群镜像表
    React:构建Web应用的未来
    集合java
    OO面向对象再认识
    cpp中this和*this区别
    react 跨级举荐通信
    【Redis】Bitmap 使用及应用场景
  • 原文地址:https://blog.csdn.net/qq974816077/article/details/136353461