利用c++/c#实现mnist手写字符识别,包括模型训练、推理预测,依赖简单,开箱即用,全部通过代码实现,支持二次开发,以及gpu加速。

训练代码
- // train.cpp
- #include
- #include
-
- #include "trainer.hpp"
-
- int main(int argc, char** argv)
- {
- yilecv::Trainer trainer;
- std::string model_file = "lenet_train_test.prototxt";
- std::string solver_file = "lenet_solver.prototxt";
- trainer.Init(model_file, solver_file, 0, "mnist_output");
- trainer.SetMetricBlobName("accuracy", true);
- trainer.Train();
- std::cout << "trainer.BestMetricValue=" << trainer.BestMetricValue() << std::endl;
- }
推理代码
- // predict.cpp
- #include
- #include
- #include
-
- #include "predictor.hpp"
-
- int main(int argc, char** argv)
- {
- yilecv::Predictor predictor;
- std::string model_file = "lenet_deploy.prototxt";
- predictor.Init(model_file, "mnist_output/best.bin", 0);
- predictor.SetNormScaleCoeff({0.00390625,0.00390625,0.00390625});
-
- std::vector<float> out = predictor.Predict("dataset/mnist_images/test/test_0_7.jpg");
-
- std::vector<int> maxN = predictor.PredictMaxN(out, 1);
- for (int i = 0; i < maxN.size(); ++i)
- {
- std::cout << maxN[i] << ":" << out[maxN[i]] << std::endl;
- }
- }
右键项目,选择属性=>c/c++=>常规=>附加包含目录,设置为yilecv库下的include目录
右键项目,选择属性=>链接器=>常规=>附加库目录,设置为yilecv库下的lib目录
右键项目,选择属性=>链接器=>输入=>附加依赖项,设置为yilecv.lib



新建c#项目

添加YileCVSharp.cs文件到项目中
新建源文件,拷贝yilecv/lib目录下的dll和网络模型相关文件到对应bin目录下,选择release/x64,运行即可
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Text;
- using System.Threading.Tasks;
- using YileCVSharp;
- namespace YileCVSharpExample
- {
- class Program
- {
- static void Main(string[] args)
- {
- // 训练
- //Trainer trainer = new Trainer();
- //string model_file = "lenet_train_test.prototxt";
- //string solver_file = "lenet_solver.prototxt";
- //trainer.Init(model_file, solver_file, 0, "mnist_output");
- //trainer.SetMetricBlobName("accuracy", true);
- //trainer.Train();
- //System.Console.WriteLine("trainer.BestMetricValue =" + ": " + trainer.BestMetricValue());
-
-
- // 推理预测
- Predictor predictor = new Predictor();
- predictor.Init("lenet_deploy.prototxt", "mnist_output/best.bin", 0);
- VectorFloat scale = new VectorFloat();
- scale.Add(0.00390625f);
- scale.Add(0.00390625f);
- scale.Add(0.00390625f);
- predictor.SetNormScaleCoeff(scale);
- VectorFloat output = predictor.Predict("mnist_images/test/test_0_7.jpg");
- VectorInt maxN = predictor.PredictMaxN(output, 1);
- for (int i = 0; i < maxN.Capacity; ++i)
- {
- System.Console.WriteLine(maxN[i] + ": " + output[maxN[i]]);
- }
- System.Console.ReadKey();
- }
- }
- }
配置好的c#工程(包含mnist数据):https://download.csdn.net/download/u012594175/89165417
配置好的c++工程(包含mnis数据):https://download.csdn.net/download/u012594175/89165559
交流QQ群
