• 基于CNTK实现迁移学习-图像分类【附部分源码】



    前言

    本文主要实现基于cntk实现迁移学习,以图像分类为例,利用ResNet模型


    一、什么是迁移学习

    通俗地讲就是站在巨人的肩膀上学习,利用已经训练得比较好的(图像特征提取能力比较强的)模型,根据自定义的数据集,粗略修改模型结构,基于之前的权值继续训练。这样做的好处是特征提取能力比较强,能加快训练。

    二、实现方式

    首先推荐本人之前的一篇关于CNTK的文章:基于CNTK/C#实现图像分类,该文即使用了迁移学习的方法。

    1.预训练模型

    本文使用的是CNTK内置的网络结构,网络结构模型如下,可自行免费下载:

    2.代码实现

    针对对C#有一定基础的同学

    1.变量定义

     //迁移学习网络结构的输入层和输出层名称,用于自定义修改网络的输入输出结构
    private static string featureNodeName = "features";
    private static string lastHiddenNodeName = "z.x";
    private static string predictionNodeName = "prediction";
    private static string pre_Model = "./PreModel/ResNet50_ImageNet_CNTK.model";
    //训练的图像的参数
    private static string ImageDir_Train = @"./DataSet_Classification_Chess\DataImage";
    private static string ImageDir_Test = @"./DataSet_Classification_Chess/test";
    private static int[] imageDims = new int[] { 224, 224, 3 };
    private static string[] classes_names = new string[] { };
    private static int TrainNum = 400;
    private static int batch_size = 4;
    private static float learning_rate = 0.0001F;
    private static float momentum = 0.9F;
    private static float l2RegularizationWeight = 0.05F;
    private static string model_path = "./resultModel";
    private static bool useGPU = true;
    private static string ext = "bmp";
    private static DeviceDescriptor device = useGPU ? DeviceDescriptor.GPUDevice(0) : DeviceDescriptor.CPUDevice;
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    2.网络构建

    这里根据节点名称进行修改:

    /// <summary>
    /// Builds a transfer-learning classifier from a pre-trained CNTK model:
    /// the feature-extraction layers are cloned with frozen weights and a new
    /// trainable Dense head with <paramref name="numClasses"/> outputs is attached.
    /// </summary>
    /// <param name="baseModelFile">Path of the pre-trained model file to load.</param>
    /// <param name="numClasses">Number of target classes for the new head.</param>
    /// <param name="device">Device (CPU/GPU) on which to load and build the model.</param>
    /// <param name="imageInput">New image input variable (shape = imageDims).</param>
    /// <param name="labelInput">New one-hot label input variable.</param>
    /// <param name="trainingLoss">Cross-entropy-with-softmax loss over the new head.</param>
    /// <param name="predictionError">Classification-error metric over the new head.</param>
    /// <returns>The composed model function ending at the new Dense layer.</returns>
    private static Function CreateTransferLearningModel(string baseModelFile, int numClasses, DeviceDescriptor device, out Variable imageInput, out Variable labelInput, out Function trainingLoss, out Function predictionError)
    {
        Function pretrainedModel = Function.Load(baseModelFile, device);

        imageInput = Variable.InputVariable(imageDims, DataType.Float);
        labelInput = Variable.InputVariable(new int[] { numClasses }, DataType.Float);
        // Mean-subtraction preprocessing (114 is the per-pixel mean used by the CNTK ResNet recipes).
        Function normalizedInput = CNTKLib.Minus(imageInput, Constant.Scalar(DataType.Float, 114.0F));

        // Locate the original input node and the last hidden node of the pre-trained net.
        Variable originalFeatureInput = pretrainedModel.Arguments.Single(a => a.Name == featureNodeName);
        Function featureExtractorOutput = pretrainedModel.FindByName(lastHiddenNodeName);

        // Clone everything up to the last hidden layer with frozen parameters,
        // substituting the normalized input for the original feature node.
        Function frozenFeatureExtractor = CNTKLib.AsComposite(featureExtractorOutput).Clone(
            ParameterCloningMethod.Freeze,
            new Dictionary<Variable, Variable>() { { originalFeatureInput, normalizedInput } });

        // New trainable classification head.
        Function classifier = Dense(frozenFeatureExtractor, numClasses, device, Activation.None, predictionNodeName);

        trainingLoss = CNTKLib.CrossEntropyWithSoftmax(classifier, labelInput);
        predictionError = CNTKLib.ClassificationError(classifier, labelInput);

        return classifier;
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    3.网络训练/验证/测试

    参考上边提供的文章,这里对每行代码不做过多解释,直接上main函数的实现方法

    // Build train/val file lists (90%/10% split) from the training image directory;
    // returns the discovered class names. (CreateDataList is defined elsewhere in the project.)
    classes_names = CreateDataList(ImageDir_Train, 0.9, Path.Combine(model_path, "train_data.txt"), Path.Combine(model_path, "val_data.txt"));

    // Minibatch source streaming images + one-hot labels from the train list.
    MinibatchSource minibatchSource = CreateMinibatchSource(Path.Combine(model_path, "train_data.txt"), imageDims, classes_names.Length);

    // Build the transfer-learning model: frozen pre-trained feature extractor + new Dense head.
    Variable imageInput, labelInput;
    Function trainingLoss, predictionError;
    Function transferLearningModel = CreateTransferLearningModel(pre_Model, classes_names.Length, device, out imageInput, out labelInput, out trainingLoss, out predictionError);

    // Learner: momentum SGD with L2 regularization.
    AdditionalLearningOptions additionalLearningOptions = new AdditionalLearningOptions() { l2RegularizationWeight = l2RegularizationWeight };
    IList<Learner> parameterLearners = new List<Learner>() {
        Learner.MomentumSGDLearner(transferLearningModel.Parameters(),
        new TrainingParameterScheduleDouble(learning_rate, 0),
        new TrainingParameterScheduleDouble(momentum, 0),
        true,
        additionalLearningOptions)};

    // Trainer wires model, loss, metric and learners together.
    var trainer = Trainer.CreateTrainer(transferLearningModel, trainingLoss, predictionError, parameterLearners);

    // Training loop. TrainNum (epochs) is rescaled to a total minibatch count:
    // epochs * samples / batch_size. (Integer division — remainder samples per epoch are dropped.)
    int outputFrequencyInMinibatches = 1; 
    int data_length = readFileLines(Path.Combine(model_path, "train_data.txt"));
    TrainNum = Convert.ToInt32(TrainNum * data_length / batch_size);
    for (int minibatchCount = 0; minibatchCount < TrainNum; ++minibatchCount)
    {
        var minibatchData = minibatchSource.GetNextMinibatch((uint)batch_size, device);

        // Stream names "image"/"labels" must match those declared in CreateMinibatchSource.
        trainer.TrainMinibatch(new Dictionary<Variable, MinibatchData>()
        {
            { imageInput, minibatchData[minibatchSource.StreamInfo("image")] },
            { labelInput, minibatchData[minibatchSource.StreamInfo("labels")] } 
        }, device);
        PrintTrainingProgress(trainer, minibatchCount, TrainNum, outputFrequencyInMinibatches);
    }

    // Persist the trained model.
    transferLearningModel.Save(Path.Combine(model_path, "Ctu_Classification.model"));

    Console.ReadLine();

    // Evaluate on the held-out validation list.
    ValidateModelWithMinibatchSource(Path.Combine(model_path, "Ctu_Classification.model"), Path.Combine(model_path, "val_data.txt"), imageDims, classes_names.Length, device);

    Console.ReadLine();

    // Inference: reload the saved model and classify every test image.
    Function model = Function.Load(Path.Combine(model_path, "Ctu_Classification.model"), device);
    string[] all_image = Directory.GetFiles(ImageDir_Test, $"*.{ext}");
    foreach (string file in all_image)
    {
        // Load() (project helper) reads and resizes the image into a float buffer.
        var inputValue = new Value(new NDArrayView(imageDims, Load(imageDims[0], imageDims[1], file), device));
        var inputDataMap = new Dictionary<Variable, Value>() { { model.Arguments[0], inputValue } };
        var outputDataMap = new Dictionary<Variable, Value>() {
                { model.Output, null }
            };
        model.Evaluate(inputDataMap, outputDataMap, device);
        var outputData = outputDataMap[model.Output].GetDenseData<float>(model.Output).First();

        // Argmax over raw logits — equivalent to argmax over softmax probabilities.
        var output = outputData.Select(x => (double)x).ToArray();
        var classIndex = Array.IndexOf(output, output.Max());
        var className = classes_names[classIndex];
        Console.WriteLine(file + " : " + className);
    }

    Console.ReadLine();
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67

    三、效果展示

    在这里插入图片描述

    四、CNTK网络结构可视化(ResNet50)

    在这里插入图片描述

  • 相关阅读:
    使用ARIMA进行时间序列预测|就代码而言
    防火墙基本概念
    LLC谐振变换器软启动过程分析与问题处理
    STM32-03基于HAL库(CubeMX+MDK+Proteus)输入检测案例(按键控制LED)
    单词接龙 II
    ESP32网络开发实例-TCP服务器数据传输
    SpringBoot发送邮件(SpringBoot整合JavaMail)
    使用SimPowerSystems并网光伏阵列研究(Simulink实现)
    8 种 Python 定时任务的解决方案
    Hystrix熔断器整合 - 服务熔断和服务降级
  • 原文地址:https://blog.csdn.net/ctu_sue/article/details/127652808