• MindSpore基础教程:LeNet-5 神经网络在MindSpore中的实现与训练


    MindSpore基础教程:LeNet-5 神经网络在MindSpore中的实现与训练

    官方文档教程使用已经弃用的MindVision模块,本文是对官方文档的更新
    深度学习在图像识别领域取得了显著的成功,LeNet-5 作为卷积神经网络的经典之作,在诸多研究和应用中占有重要地位。本文将详细介绍如何使用 MindSpore 框架实现并训练一个 LeNet-5 神经网络,专注于处理MNIST手写数字数据集。

    前言

    MindSpore 是华为推出的一种新型深度学习框架,旨在为用户提供高效、易用的编程体验。接下来,我们将通过实例来展示如何在 MindSpore 中构建、训练和评估一个经典的 LeNet-5 神经网络。

    环境配置

    MindSpore官网

    LeNet-5 网络结构简介

    LeNet-5 是一个简单的卷积神经网络,包含两个卷积层和三个全连接层。它经常被用于图像识别任务,特别是在处理像 MNIST 这样的手写数字数据集时表现出色。

    数据集准备与预处理

    首先,我们需要准备并预处理数据集。在这个例子中,我们将使用 MNIST 数据集。以下函数 create_dataset 负责加载数据集,并进行必要的预处理:

    def create_dataset(data_path, batch_size=32, repeat_size=1):
        """
        创建用于训练的MNIST数据集。
    
        此函数负责加载MNIST数据集,对数据进行预处理和转换,以便它们可以用于训练神经网络。数据预处理包括调整图像大小、重新缩放和类型转换。
    
        参数:
            data_path (str): MNIST数据集的路径。这应该是包含MNIST数据文件的目录路径。
            batch_size (int, 可选): 每个数据批次的大小。默认值为32。
            repeat_size (int, 可选): 数据集重复的次数。这用于增加数据集的大小。默认值为1。
    
        步骤:
            1. 加载MNIST数据集。
            2. 对图像执行大小调整操作,将图像大小统一调整为32x32像素。
            3. 对图像进行重新缩放和标准化处理。先将像素值缩放到0-1之间,然后进行标准化。
            4. 将图像的格式从高宽通道(HWC)转换为通道高宽(CHW)。
            5. 对标签进行类型转换,将其转换为整型(int32)。
            6. 对数据集进行洗牌、批处理和重复操作,以准备训练过程。
    
        返回:
            返回一个处理过的MNIST数据集,可以直接用于模型训练。
    
        注意:
            - 数据集的预处理步骤对于训练深度学习模型来说是非常重要的,它们会影响训练的效果和速度。
            - 调整batch_size和repeat_size可以影响模型训练时的内存消耗和速度。
        """
        mnist_dataset = ds.MnistDataset(data_path)
    
        resize_operation = vision.Resize((32, 32), interpolation=Inter.LINEAR)
        rescale_normalization_op = vision.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)
        rescale_op = vision.Rescale(1.0 / 255.0, 0.0)
        hwc_to_chw_op = vision.HWC2CHW()
        type_cast_op = transforms.TypeCast(mstype.int32)
    
        mnist_dataset = mnist_dataset.map(input_columns="label", operations=type_cast_op)
        mnist_dataset = mnist_dataset.map(input_columns="image",
                                          operations=[resize_operation, rescale_op, rescale_normalization_op,
                                                      hwc_to_chw_op])
        mnist_dataset = mnist_dataset.shuffle(buffer_size=10000)
        mnist_dataset = mnist_dataset.batch(batch_size, drop_remainder=True)
        mnist_dataset = mnist_dataset.repeat(repeat_size)
    
        return mnist_dataset
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43

    这个函数将数据集中的图像调整为统一的大小,并进行重新缩放和标准化。

    构建 LeNet-5 模型

    LeNet-5 模型的构建在 LeNet5 类中实现。此类定义了网络的各层及其排列:

    class LeNet5(nn.Cell):
        """
        LeNet-5 神经网络结构。
    
        这是一个经典的卷积神经网络,通常用于图像识别任务。它包含了两个卷积层和三个全连接层。
    
        参数:
            num_class (int): 输出层的类别数量。默认为10,适用于MNIST数据集。
            num_channel (int): 输入图像的通道数。对于灰度图像,此值为1。
    
        组件:
            - conv1: 第一个卷积层,使用有效填充。
            - conv2: 第二个卷积层,同样使用有效填充。
            - fc1: 第一个全连接层。
            - fc2: 第二个全连接层。
            - fc3: 第三个全连接层,输出层。
            - relu: 激活函数,使用ReLU。
            - max_pool2d: 最大池化层。
            - flatten: 扁平化层,用于全连接层之前的数据转换。
    
        方法:
            - construct(x): 定义了前向传播的过程。
        """
    
        def __init__(self, num_class=10, num_channel=1):
            super(LeNet5, self).__init__()
            self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
            self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
            self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
            self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
            self.relu = nn.ReLU()
            self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
            self.flatten = nn.Flatten()
    
        def construct(self, x):
            x = self.conv1(x)
            x = self.relu(x)
            x = self.max_pool2d(x)
            x = self.conv2(x)
            x = self.relu(x)
            x = self.max_pool2d(x)
            x = self.flatten(x)
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            x = self.relu(x)
            x = self.fc3(x)
            return x
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50

    训练模型

    接下来,我们定义 train_network 函数来训练模型。此函数接受模型实例、数据集路径和其他训练参数:

    def train_network(model, epoch_size, data_path, repeat_size, checkpoint_callback):
        """
        训练神经网络模型。
    
        此函数负责初始化数据集,然后使用指定的模型进行训练。在训练过程中,它将记录损失并保存模型的检查点。
    
        参数:
            model (Model): 要训练的神经网络模型。
            epoch_size (int): 训练过程中遍历数据集的次数。
            data_path (str): 训练数据集的路径。
            repeat_size (int): 数据集的重复次数,用于扩充数据集。
            checkpoint_callback (Callback): 用于保存模型检查点的回调函数。
    
        过程:
            - 使用 `create_dataset` 函数创建训练数据集。
            - 调用模型的 `train` 方法进行训练。
            - 在训练过程中,会通过回调函数记录损失和保存检查点。
    
        注意:
            - 确保提供的 `data_path` 包含适当格式的数据。
        """
        print("============== 开始训练 ==============")
        ds_train = create_dataset(data_path, 32, repeat_size)
        model.train(epoch_size, ds_train, callbacks=[checkpoint_callback, LossMonitor(), TimeMonitor()],
                    dataset_sink_mode=False)
        print("============== 训练结束 ==============")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26

    主函数

    最后,我们通过 train 函数和 parse_arguments 函数将所有步骤串联起来。train 函数负责初始化模型、损失函数、优化器和检查点回调,然后调用 train_network 进行训练:

    def train(args):
        """
        初始化并训练LeNet-5神经网络模型。
    
        此函数设置了网络模型、损失函数、优化器,并定义了模型检查点。然后,使用指定的参数调用 `train_network` 函数来进行模型的训练。
    
        参数:
            args (Namespace): 一个包含训练参数的命名空间对象。此对象应该包含以下属性:
                - epochs (int): 模型训练的迭代次数。
                - data_url (str): 训练数据集的路径。
                - output_path (str): 保存模型检查点的路径。
    
        过程:
            1. 创建 LeNet-5 网络实例。
            2. 定义损失函数为 Softmax Cross-Entropy。
            3. 定义优化器为 Momentum 优化器。
            4. 创建模型实例,并指定网络、损失函数、优化器和评估指标。
            5. 设置模型检查点配置。
            6. 初始化模型检查点回调函数。
            7. 调用 `train_network` 函数进行训练。
    
        注意:
            - 确保 `args` 对象包含正确和完整的训练参数。
            - 调整优化器和损失函数的参数可以对训练结果产生影响。
            - 模型检查点将保存在 `args.output_path` 指定的路径中。
        """
        net = LeNet5()
        net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
        net_opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
    
        model = Model(net, net_loss, net_opt, metrics={"Accuracy": nn.Accuracy()})
    
        config_checkpoint = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
        checkpoint_callback = ModelCheckpoint(prefix="checkpoint_lenet", directory=args.output_path,
                                              config=config_checkpoint)
    
        train_network(model, args.epochs, args.data_url, 1, checkpoint_callback)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37

    推理

    # 加载网络
    param_dict = load_checkpoint("/root/MyCode/pycharm/lenet5/ckpt/checkpoint_lenet-19_1884.ckpt")
    network = LeNet5(num_class=NUM_CLASS, num_channel=1)  # 用您定义的LeNet5类创建模型实例
    load_param_into_net(network, param_dict)  # 将参数加载到网络中
    model = Model(network)
    
    
    def predict_digit(img):
        # 图像预处理
        img = cv2.resize(img, (32, 32))  # 调整图像大小为32x32
        img = np.array(img, dtype=np.float32)  # 转换图像数据类型
        img = (img - 0.1307) / 0.3081  # 对图像进行标准化处理
        img = img[np.newaxis, np.newaxis, :, :]  # 改变图像形状以符合网络输入要求(1, 1, 32, 32)
    
        # 将图像数据转换为MindSpore张量
        img_tensor = Tensor(img)
    
        # 使用模型进行预测
        output = model.predict(img_tensor)
    
        # 将输出转换为概率分布
        probabilities = Softmax()(output)
    
        # 获取每个类别的概率
        probabilities_np = probabilities.asnumpy()[0]
    
        # 将概率转换为字典格式
        labels = [str(i) for i in range(10)]  # 类别标签,例如"0", "1", "2", ..., "9"
        probabilities_dict = {label: prob for label, prob in zip(labels, probabilities_np)}
    
        return probabilities_dict
    
    
    gr.Interface(
        fn=predict_digit,
        inputs=gr.Image(image_mode='L'),
        outputs=gr.Label(num_top_classes=NUM_CLASS),
        live=False,
        css=".footer {display:none !important}",
        title="0-9数字画板",
        description="画0-9数字",
        thumbnail="https://raw.githubusercontent.com/gradio-app/real-time-mnist/master/thumbnail2.png"
    ).launch()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44

    结论

    通过本文的指南,您可以在 MindSpore 框架中实现并训练一个经典的 LeNet-5 神经网络。LeNet-5 在图像识别任务中展现了卓越的性能,而 MindSpore 的高效和易用性使得深度学习研究和开发更加便捷。您可以根据本文的指导进行实验,并根据需要调整网络结构和训练参数。

  • 相关阅读:
    STM32WL开发之易智联LORA评估板上按键KEY的配置与应用
    【面试高高手】 —— Java集合篇(23题)
    MySQL:连接查询 | 内连接,外连接
    Ubuntu 和 Windows 文件互传
    SH-CST 2022丨SpeechHome 语音技术研讨会
    数据结构与算法系列一之整数、数组及字符串
    微软商店无法访问
    Qt --- Day03
    @KafkaListener注解详解(一)| 常用参数详解
    用JMeter对HTTP接口进行压测(一)压测脚本的书写、调试思路
  • 原文地址:https://blog.csdn.net/qq_42896106/article/details/134510338