• Building Network Models with LibTorch

    Linear layer

    // Declare an empty module holder, then construct a 24 -> 1 linear layer and
    // register it under the name "ln" so its parameters are tracked by the parent module.
    torch::nn::Linear ln1{nullptr};
    ln1 = register_module("ln", torch::nn::Linear(torch::nn::LinearOptions(24, 1)));
    
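    To see these two lines in context, here is a minimal sketch that wraps them in a complete module; the struct name TinyRegressor and the 24-feature input are illustrative assumptions, not part of the original snippet.

    #include <torch/torch.h>

    // Illustrative sketch: one Linear layer registered in the constructor and used in forward().
    struct TinyRegressor : torch::nn::Module {
        torch::nn::Linear ln1{nullptr};

        TinyRegressor() {
            // 24 input features -> 1 output; register_module makes the layer's
            // parameters visible to parameters(), to(device), save()/load(), etc.
            ln1 = register_module("ln", torch::nn::Linear(torch::nn::LinearOptions(24, 1)));
        }

        torch::Tensor forward(torch::Tensor x) {
            return ln1->forward(x);  // x: [batch, 24] -> [batch, 1]
        }
    };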

    BatchNorm1d layer

    // Declare an empty holder, then register a BatchNorm1d over 10 features under the name "bn".
    torch::nn::BatchNorm1d bn1{nullptr};
    bn1 = register_module("bn", torch::nn::BatchNorm1d(10));
    
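    As with the Linear layer, the BatchNorm1d holder is declared empty and registered in the module's constructor. The sketch below is a hedged illustration (the struct name NormalizedBlock and the 24 -> 10 shapes are assumptions) of how BatchNorm1d(10) normalizes the 10 output features of a preceding Linear layer.

    #include <torch/torch.h>

    // Illustrative sketch: Linear followed by BatchNorm1d over 10 features.
    struct NormalizedBlock : torch::nn::Module {
        torch::nn::Linear fc{nullptr};
        torch::nn::BatchNorm1d bn{nullptr};

        NormalizedBlock() {
            fc = register_module("fc", torch::nn::Linear(24, 10));
            bn = register_module("bn", torch::nn::BatchNorm1d(10));  // expects [batch, 10] input
        }

        torch::Tensor forward(torch::Tensor x) {
            // Batch statistics are used in training mode, running statistics in eval mode.
            return torch::relu(bn->forward(fc->forward(x)));
        }
    };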

    Official example:

    #include <torch/torch.h>
     
    // Define a new Module.
    struct Net : torch::nn::Module {
      Net() {
        // Construct and register two Linear submodules.
        fc1 = register_module("fc1", torch::nn::Linear(784, 64));
        fc2 = register_module("fc2", torch::nn::Linear(64, 32));
        fc3 = register_module("fc3", torch::nn::Linear(32, 10));
      }
    
      // Implement the Net's algorithm.
      torch::Tensor forward(torch::Tensor x) {
        // Use one of many tensor manipulation functions.
        x = torch::relu(fc1->forward(x.reshape({x.size(0), 784})));
        x = torch::dropout(x, /*p=*/0.5, /*train=*/is_training());
        x = torch::relu(fc2->forward(x));
        x = torch::log_softmax(fc3->forward(x), /*dim=*/1);
        return x;
      }
    
    
      // Use one of many "standard library" modules.
      torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
    };
    
    int main() {
      // Create a new Net.
      auto net = std::make_shared<Net>();
    
      // Create a multi-threaded data loader for the MNIST dataset.
      auto data_loader = torch::data::make_data_loader(
          torch::data::datasets::MNIST("./data").map(
              torch::data::transforms::Stack<>()),
          /*batch_size=*/64);
    
      // Instantiate an SGD optimization algorithm to update our Net's parameters.
      torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);
    
      for (size_t epoch = 1; epoch <= 10; ++epoch) {
        size_t batch_index = 0;
        // Iterate the data loader to yield batches from the dataset.
        for (auto& batch : *data_loader) {
          // Reset gradients.
          optimizer.zero_grad();
          // Execute the model on the input data.
          torch::Tensor prediction = net->forward(batch.data);
          // Compute a loss value to judge the prediction of our model.
          torch::Tensor loss = torch::nll_loss(prediction, batch.target);
          // Compute gradients of the loss w.r.t. the parameters of our model.
          loss.backward();
          // Update the parameters based on the calculated gradients.
          optimizer.step();
          // Output the loss and checkpoint every 100 batches.
          if (++batch_index % 100 == 0) {
            std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
                      << " | Loss: " << loss.item<float>() << std::endl;
            // Serialize your model periodically as a checkpoint.
            torch::save(net, "net.pt");
          }
        }
      }
    }
    
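    The loop above checkpoints the model with torch::save(net, "net.pt"). A matching inference-side sketch might look like the following; the random 28x28 input is a stand-in for a real MNIST image, so treat this as an assumption-based illustration rather than part of the official example.

    #include <torch/torch.h>
    #include <iostream>

    int main() {
        // Rebuild the same Net defined above, then restore the checkpointed parameters.
        auto net = std::make_shared<Net>();
        torch::load(net, "net.pt");
        net->eval();                    // disable dropout for inference

        torch::NoGradGuard no_grad;     // no gradient tracking needed here
        torch::Tensor input = torch::rand({1, 1, 28, 28});  // fake MNIST-shaped image
        torch::Tensor log_probs = net->forward(input);      // forward() reshapes it to [1, 784]
        std::cout << "Predicted digit: "
                  << log_probs.argmax(1).item<int64_t>() << std::endl;
        return 0;
    }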
    A pared-down version of the same model, as a starting skeleton:

    #include <torch/torch.h>

    // Define the model
    struct Net : torch::nn::Module {
        Net() {
            // Construct and register the three linear submodules
            fc1 = register_module("fc1", torch::nn::Linear(784, 64));
            fc2 = register_module("fc2", torch::nn::Linear(64, 32));
            fc3 = register_module("fc3", torch::nn::Linear(32, 10));
        }

        // Implement the network's forward pass (same flow as the official example above)
        torch::Tensor forward(torch::Tensor x) {
            x = torch::relu(fc1->forward(x.reshape({x.size(0), 784})));
            x = torch::dropout(x, /*p=*/0.5, /*train=*/is_training());
            x = torch::relu(fc2->forward(x));
            x = torch::log_softmax(fc3->forward(x), /*dim=*/1);
            return x;
        }

        // Module holders; declaring them {nullptr} lets register_module fill them in the constructor
        torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
    };

    int main(int argc, char** argv) {

        return 0;
    }
    
    
    
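    One way to fill in the empty main() above is a quick smoke test: push a random batch through forward() and list the parameters that register_module() exposed. The batch size of 4 below is an arbitrary choice for the sketch, and std::cout assumes <iostream> is included alongside <torch/torch.h>.

    int main(int argc, char** argv) {
        auto net = std::make_shared<Net>();

        torch::Tensor x = torch::rand({4, 784});   // batch of 4 fake flattened 28x28 images
        torch::Tensor y = net->forward(x);         // -> [4, 10] log-probabilities
        std::cout << "output sizes: " << y.sizes() << std::endl;

        // Registered submodules appear under their registered names,
        // e.g. fc1.weight, fc1.bias, fc2.weight, ...
        for (const auto& p : net->named_parameters()) {
            std::cout << p.key() << "  " << p.value().sizes() << std::endl;
        }
        return 0;
    }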
  • Original article: https://blog.csdn.net/qq_41375318/article/details/126274285