目录
(1)实例化网络,该网络模型必须继承torch::nn::Module类;
(2)获取训练数据,其中输入数据维度和目标数据维度都是(b,n);
(3)实例化优化器,这里用的是Adam,学习率是0.0005;
(4)forward,在求mse_loss,backward,step。
- #include
- #include
- #include"mlp.h"
-
- int main()
- {
- MLP mlp(20, 2); // model
- at::Tensor input_x = torch::rand({ 4, 20 }); // data
- at::Tensor input_y = torch::ones({ 4, 2 });
- torch::optim::Adam optimizer(mlp.parameters(), 0.001);
- for (int epoch = 0; epoch < 100; epoch++) // train
- {
- optimizer.zero_grad();
- at::Tensor output = mlp.forward(input_x);
- at::Tensor loss = torch::mse_loss(output, input_y);
- loss.backward();
- optimizer.step();
- std::cout << loss.item().toFloat() << std::endl;
- }
- }
out shape: [2, 1]; target shape:[2, 1]
MLP多层感知机,也就是全连接网络.
- #ifndef MLP_H
- #define MLP_H
- #endif // MLP_H
-
- #include
- #include
-
- // 小模块:fc+bn+relu
- class LinearBnReluImpl : public torch::nn::Module {
- public:
- LinearBnReluImpl(int intput_features, int output_features);
- torch::Tensor forward(torch::Tensor x);
- private:
- //layers
- torch::nn::Linear ln{ nullptr }; // 定义私有成员,先构造函数初始化,再在forward函数使用。
- torch::nn::BatchNorm1d bn{ nullptr };
- };
- TORCH_MODULE(LinearBnRelu);
-
-
- class MLP : public torch::nn::Module {
- public:
- MLP(int in_features, int out_features); // 构造函数:输入特征维度,和输出特征维度
- torch::Tensor forward(torch::Tensor x); // 推理函数
- private:
- int mid_features[3] = { 32,64,128 }; // 中间层特征维度
- LinearBnRelu ln1{ nullptr }; // 3个(linear + bn + relu)
- LinearBnRelu ln2{ nullptr };
- LinearBnRelu ln3{ nullptr };
- torch::nn::Linear out_ln{ nullptr }; // 普通的linear层
- };
3个linear+bn+relu,最后接一个linear.
- #include "mlp.h"
-
- // 实现LinearBnRelu
- // 注册线性层、bn层
- LinearBnReluImpl::LinearBnReluImpl(int in_features, int out_features) {
- ln = register_module("ln", torch::nn::Linear(torch::nn::LinearOptions(in_features, out_features)));
- // 注意bn操作时,训练时batch_size必须设置成大于1,否则没意义且会报错,测试时会屏蔽此操作
- bn = register_module("bn", torch::nn::BatchNorm1d(out_features));
- }
- // linear->relu->bn
- torch::Tensor LinearBnReluImpl::forward(torch::Tensor x) {
- x = torch::relu(ln->forward(x));
- x = bn(x);
- return x;
- }
-
- MLP::MLP(int in_features, int out_features) {
- ln1 = LinearBnRelu(in_features, mid_features[0]); // 初始化
- ln2 = LinearBnRelu(mid_features[0], mid_features[1]);
- ln3 = LinearBnRelu(mid_features[1], mid_features[2]);
- out_ln = torch::nn::Linear(mid_features[2], out_features);
-
- ln1 = register_module("ln1", ln1); // 构造函数注册轮子
- ln2 = register_module("ln2", ln2);
- ln3 = register_module("ln3", ln3);
- out_ln = register_module("out_ln", out_ln);
- }
-
- torch::Tensor MLP::forward(torch::Tensor x) {
- x = ln1->forward(x); // 逐个forward,因为每个都是module,有各自的forward函数。
- x = ln2->forward(x); //
- x = ln3->forward(x);
- x = out_ln->forward(x);
- return x;
- }
loss逐渐收敛
参考:LibtorchTutorials/lesson3-BasicModels at main · AllentDan/LibtorchTutorials · GitHub