• libtorch c++ 定义全链接网络


    目录

    1. main函数

    2. 搭建MLP网络

    2.1 mlp.h

    2.2 mlp.cpp

    1. main函数

    (1)实例化网络,该网络模型必须继承torch::nn::Module类;

    (2)获取训练数据,其中输入数据维度和目标数据维度都是(b,n);

    (3)实例化优化器,这里用的是Adam,学习率是0.0005;

    (4)forward,在求mse_loss,backward,step。

    1. #include
    2. #include
    3. #include"mlp.h"
    4. int main()
    5. {
    6. MLP mlp(20, 2); // model
    7. at::Tensor input_x = torch::rand({ 4, 20 }); // data
    8. at::Tensor input_y = torch::ones({ 4, 2 });
    9. torch::optim::Adam optimizer(mlp.parameters(), 0.001);
    10. for (int epoch = 0; epoch < 100; epoch++) // train
    11. {
    12. optimizer.zero_grad();
    13. at::Tensor output = mlp.forward(input_x);
    14. at::Tensor loss = torch::mse_loss(output, input_y);
    15. loss.backward();
    16. optimizer.step();
    17. std::cout << loss.item().toFloat() << std::endl;
    18. }
    19. }

    out shape: [2, 1]; target shape:[2, 1]

    2. 搭建MLP网络

    MLP多层感知机,也就是全连接网络.

    2.1 mlp.h

    1. #ifndef MLP_H
    2. #define MLP_H
    3. #endif // MLP_H
    4. #include
    5. #include
    6. // 小模块:fc+bn+relu
    7. class LinearBnReluImpl : public torch::nn::Module {
    8. public:
    9. LinearBnReluImpl(int intput_features, int output_features);
    10. torch::Tensor forward(torch::Tensor x);
    11. private:
    12. //layers
    13. torch::nn::Linear ln{ nullptr }; // 定义私有成员,先构造函数初始化,再在forward函数使用。
    14. torch::nn::BatchNorm1d bn{ nullptr };
    15. };
    16. TORCH_MODULE(LinearBnRelu);
    17. class MLP : public torch::nn::Module {
    18. public:
    19. MLP(int in_features, int out_features); // 构造函数:输入特征维度,和输出特征维度
    20. torch::Tensor forward(torch::Tensor x); // 推理函数
    21. private:
    22. int mid_features[3] = { 32,64,128 }; // 中间层特征维度
    23. LinearBnRelu ln1{ nullptr }; // 3个(linear + bn + relu)
    24. LinearBnRelu ln2{ nullptr };
    25. LinearBnRelu ln3{ nullptr };
    26. torch::nn::Linear out_ln{ nullptr }; // 普通的linear层
    27. };

    2.2 mlp.cpp

    3个linear+bn+relu,最后接一个linear.

    1. #include "mlp.h"
    2. // 实现LinearBnRelu
    3. // 注册线性层、bn层
    4. LinearBnReluImpl::LinearBnReluImpl(int in_features, int out_features) {
    5. ln = register_module("ln", torch::nn::Linear(torch::nn::LinearOptions(in_features, out_features)));
    6. // 注意bn操作时,训练时batch_size必须设置成大于1,否则没意义且会报错,测试时会屏蔽此操作
    7. bn = register_module("bn", torch::nn::BatchNorm1d(out_features));
    8. }
    9. // linear->relu->bn
    10. torch::Tensor LinearBnReluImpl::forward(torch::Tensor x) {
    11. x = torch::relu(ln->forward(x));
    12. x = bn(x);
    13. return x;
    14. }
    15. MLP::MLP(int in_features, int out_features) {
    16. ln1 = LinearBnRelu(in_features, mid_features[0]); // 初始化
    17. ln2 = LinearBnRelu(mid_features[0], mid_features[1]);
    18. ln3 = LinearBnRelu(mid_features[1], mid_features[2]);
    19. out_ln = torch::nn::Linear(mid_features[2], out_features);
    20. ln1 = register_module("ln1", ln1); // 构造函数注册轮子
    21. ln2 = register_module("ln2", ln2);
    22. ln3 = register_module("ln3", ln3);
    23. out_ln = register_module("out_ln", out_ln);
    24. }
    25. torch::Tensor MLP::forward(torch::Tensor x) {
    26. x = ln1->forward(x); // 逐个forward,因为每个都是module,有各自的forward函数。
    27. x = ln2->forward(x); //
    28. x = ln3->forward(x);
    29. x = out_ln->forward(x);
    30. return x;
    31. }

    loss逐渐收敛

     

    参考:LibtorchTutorials/lesson3-BasicModels at main · AllentDan/LibtorchTutorials · GitHub

  • 相关阅读:
    大语言模型在天猫AI导购助理项目的实践!
    CSS英文单词强制截断换行
    RT-Thread学习笔记(一):认识RT-Thread系统
    js适配文件
    【SQL】之索引
    Docker进阶——再次认识docker的概念 & Docker的结构 & Docker镜像结构 & 镜像的构建方式
    接口测试之文件上传
    7-155 字符转换
    普中精灵开发板stm32烧录程序失败
    Unity与 DLL文件 | Mac中使用 Xcode项目使用C++生成 .dylib文件
  • 原文地址:https://blog.csdn.net/jizhidexiaoming/article/details/127125955