• LibTorch之网络模型构建


    LibTorch之网络模型构建

    线性层

    // Declare an empty linear-layer handle; {nullptr} defers construction.
    torch::nn::Linear ln1{nullptr};
    // Construct a 24-in / 1-out linear layer and register it under the name
    // "ln" — register_module implies this runs inside a torch::nn::Module
    // subclass, so the layer's parameters are tracked by the parent module.
    ln1 = register_module("ln", torch::nn::Linear(torch::nn::LinearOptions(24, 1)));
    
    • 1
    • 2

    BatchNorm1d层

    // Declare an empty BatchNorm1d handle; {nullptr} defers construction.
    torch::nn::BatchNorm1d bn1{nullptr};
    // Construct a batch-norm layer over 10 features and register it as "bn"
    // — register_module implies this runs inside a torch::nn::Module
    // subclass, so the layer's parameters are tracked by the parent module.
    bn1 = register_module("bn", torch::nn::BatchNorm1d(10));
    
    • 1
    • 2

    官方示例:

    #include <torch/torch.h>
    // Use one of many "standard library" modules.
    // Define a new Module. The Linear handles are declared as members of the
    // struct (as in the official LibTorch MNIST example), not as globals, so
    // each Net instance owns its own layers.
    struct Net : torch::nn::Module {
      Net() {
        // Construct and register three Linear submodules so parameters()
        // and torch::save/torch::load can see their weights.
        fc1 = register_module("fc1", torch::nn::Linear(784, 64));
        fc2 = register_module("fc2", torch::nn::Linear(64, 32));
        fc3 = register_module("fc3", torch::nn::Linear(32, 10));
      }

      // Implement the Net's algorithm: 784 -> 64 -> 32 -> 10 with ReLU
      // activations, dropout during training, and a log-softmax output
      // suitable for torch::nll_loss.
      torch::Tensor forward(torch::Tensor x) {
        // Flatten each sample to a 784-vector before the first linear layer.
        x = torch::relu(fc1->forward(x.reshape({x.size(0), 784})));
        // Dropout only fires while the module is in training mode.
        x = torch::dropout(x, /*p=*/0.5, /*train=*/is_training());
        x = torch::relu(fc2->forward(x));
        // log_softmax over the class dimension pairs with nll_loss in main().
        x = torch::log_softmax(fc3->forward(x), /*dim=*/1);
        return x;
      }

      // Layer handles; {nullptr} leaves them empty until the ctor registers them.
      torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
    };
    
    int main() {
      // Create a new Net.
      auto net = std::make_shared<Net>();
    
      // Create a multi-threaded data loader for the MNIST dataset.
      auto data_loader = torch::data::make_data_loader(
          torch::data::datasets::MNIST("./data").map(
              torch::data::transforms::Stack<>()),
          /*batch_size=*/64);
    
      // Instantiate an SGD optimization algorithm to update our Net's parameters.
      torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);
    
      for (size_t epoch = 1; epoch <= 10; ++epoch) {
        size_t batch_index = 0;
        // Iterate the data loader to yield batches from the dataset.
        for (auto& batch : *data_loader) {
          // Reset gradients.
          optimizer.zero_grad();
          // Execute the model on the input data.
          torch::Tensor prediction = net->forward(batch.data);
          // Compute a loss value to judge the prediction of our model.
          torch::Tensor loss = torch::nll_loss(prediction, batch.target);
          // Compute gradients of the loss w.r.t. the parameters of our model.
          loss.backward();
          // Update the parameters based on the calculated gradients.
          optimizer.step();
          // Output the loss and checkpoint every 100 batches.
          if (++batch_index % 100 == 0) {
            std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
                      << " | Loss: " << loss.item<float>() << std::endl;
            // Serialize your model periodically as a checkpoint.
            torch::save(net, "net.pt");
          }
        }
      }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    #include <torch/torch.h>
    
    //  定义模型
    struct Net :torch::nn::Module {
    	Net() {
    		// 构造和注册两个线性子模块
    		fc1 = register_module("fc1",torch::nn::Linear(784,64));
    		fc2 = register_module("fc2",torch::nn::Linear(64,32));
    		fc3 = register_module("fc3",torch::nn::Linear(32,10));
    
    	}
    
    	// 实现网络的前向传播
    	torch::Tensor forward() {
    
    	}
    };
    
    int main(int argc, char** argv) {
    	// Command-line arguments are currently unused by this skeleton.
    	(void)argc;
    	(void)argv;
    	return 0;
    }
    
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
  • 相关阅读:
    Git Commit Message 规范实践
    从0到1实现 OpenTiny 组件库跨框架技术
    Opencv学习笔记-第1篇 读显存图
    C++学习路线(二十五)
    论文阅读_深度学习的医疗异常检测综述
    springboot+基于Java的果蔬产品销售系统 毕业设计-附源码131110
    网络面试-ox09 http是如何维持用户的状态?
    .NET性能优化-你应该为集合类型设置初始大小
    17. 电话号码的字母组合
    公司新招了一个00后软件测试工程师,上来一顿操作给我看呆了...
  • 原文地址:https://blog.csdn.net/qq_41375318/article/details/126274285