• 昇思25天学习打卡营第1天 | 快速入门


    内容介绍:通过MindSpore的API来快速实现一个简单的深度学习模型。

    具体内容:

    1. 导包

    # MindSpore core, neural-network layers, and dataset utilities used below.
    import mindspore
    from mindspore import nn
    from mindspore.dataset import vision, transforms
    from mindspore.dataset import MnistDataset

    2. 下载数据集

    # Fetch the MNIST archive and extract it into the current directory.
    # NOTE(review): replace=True presumably forces a fresh download even if the
    # files already exist — confirm against the `download` package docs.
    from download import download

    url = ("https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/"
           "notebook/datasets/MNIST_Data.zip")
    path = download(url, "./", kind="zip", replace=True)

    3. 获取数据对象

    # Build dataset objects for the train/test splits of the extracted MNIST_Data folder.
    train_dataset = MnistDataset('MNIST_Data/train')
    test_dataset = MnistDataset('MNIST_Data/test')

    4. 数据处理

    def datapipe(dataset, batch_size):
        """Apply image/label transforms to *dataset* and batch it.

        Images are scaled by 1/255, normalized with mean 0.1307 / std 0.3081
        (the commonly used MNIST statistics), and converted from HWC to CHW
        layout. Labels are cast to int32.

        Args:
            dataset: A MindSpore dataset with 'image' and 'label' columns.
            batch_size: Number of samples per batch.

        Returns:
            The transformed, batched dataset.
        """
        image_transforms = [
            vision.Rescale(1.0 / 255.0, 0),
            vision.Normalize(mean=(0.1307,), std=(0.3081,)),
            vision.HWC2CHW(),
        ]
        label_transform = transforms.TypeCast(mindspore.int32)
        dataset = dataset.map(image_transforms, 'image')
        dataset = dataset.map(label_transform, 'label')
        dataset = dataset.batch(batch_size)
        return dataset


    train_dataset = datapipe(train_dataset, 64)
    test_dataset = datapipe(test_dataset, 64)

    5. 使用 create_tuple_iterator 或 create_dict_iterator 对数据集进行迭代访问

    # Inspect a single batch with each iterator style, then stop (break).
    # Tuple iterator: columns are yielded positionally as (image, label).
    for image, label in test_dataset.create_tuple_iterator():
        print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")
        print(f"Shape of label: {label.shape} {label.dtype}")
        break

    # Dict iterator: each item is a dict keyed by column name.
    for data in test_dataset.create_dict_iterator():
        print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")
        print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")
        break

    6. 网络构建

    class Network(nn.Cell):
        """Fully connected classifier: 28*28 inputs -> 512 -> 512 -> 10 logits."""

        def __init__(self):
            super().__init__()
            # Flatten each sample to a vector; the first Dense layer expects
            # 28*28 features, so inputs are assumed to be 1x28x28 images.
            self.flatten = nn.Flatten()
            self.dense_relu_sequential = nn.SequentialCell(
                nn.Dense(28*28, 512),
                nn.ReLU(),
                nn.Dense(512, 512),
                nn.ReLU(),
                nn.Dense(512, 10),
            )

        def construct(self, x):
            """Forward pass: return unnormalized class scores (logits) for *x*."""
            x = self.flatten(x)
            logits = self.dense_relu_sequential(x)
            return logits


    model = Network()
    print(model)

    7. 模型训练

    loss_fn = nn.CrossEntropyLoss()
    optimizer = nn.SGD(model.trainable_params(), 1e-2)


    # 1. Define forward function
    def forward_fn(data, label):
        """Compute the batch loss; the logits are returned as auxiliary output."""
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss, logits


    # 2. Get gradient function
    # grad_position=None: differentiate w.r.t. the optimizer's weights only.
    # has_aux=True: only the first output (loss) is differentiated; logits pass through.
    grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)


    # 3. Define function of one-step training
    def train_step(data, label):
        """Run one forward/backward pass, apply the SGD update, and return the loss."""
        (loss, _), grads = grad_fn(data, label)
        optimizer(grads)
        return loss


    def train(model, dataset):
        """Train for one epoch over *dataset*, printing the loss every 100 batches.

        NOTE: forward_fn/train_step use the module-level ``model``, not this
        parameter, so the two must be the same network.
        """
        size = dataset.get_dataset_size()
        model.set_train()
        for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
            loss = train_step(data, label)
            if batch % 100 == 0:
                print(f"loss: {loss.asnumpy():>7f} [{batch:>3d}/{size:>3d}]")

    8. 测试函数

    def test(model, dataset, loss_fn):
        """Evaluate *model* on *dataset* and print accuracy and average loss.

        Args:
            model: The network to evaluate; it is put into non-training mode.
            dataset: Batched dataset yielding (data, label) tuples.
            loss_fn: Loss function applied to each batch's predictions.

        Returns:
            A ``(accuracy, avg_loss)`` tuple: accuracy is the fraction of correct
            predictions, avg_loss is the summed per-batch loss divided by the
            dataset size reported by ``get_dataset_size()``.
        """
        num_batches = dataset.get_dataset_size()
        model.set_train(False)
        total, test_loss, correct = 0, 0, 0
        for data, label in dataset.create_tuple_iterator():
            pred = model(data)
            total += len(data)
            test_loss += loss_fn(pred, label).asnumpy()
            # Predicted class = index of the largest logit along axis 1.
            correct += (pred.argmax(1) == label).asnumpy().sum()
        test_loss /= num_batches
        correct /= total
        print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
        return correct, test_loss

    9. 训练过程

    # Alternate one training epoch with one evaluation pass.
    epochs = 3
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(model, train_dataset)
        test(model, test_dataset, loss_fn)
    print("Done!")

    通过训练可以看出 loss 不断降低、Accuracy 不断升高，可以通过调参达到更好的效果。

    10. 保存模型

    # Write the trained network's parameters to a checkpoint file.
    mindspore.save_checkpoint(model, "model.ckpt")
    print("Saved Model to model.ckpt")

    11. 加载模型

    # Rebuild the architecture, then restore the saved weights into it.
    model = Network()
    param_dict = mindspore.load_checkpoint("model.ckpt")
    param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
    # Network parameters that were not loaded; an empty list means all loaded.
    print(param_not_load)

  • 相关阅读:
    从零开始学JAVA(05):面向对象编程--01
    四则运算Java版
    COCI2021-2022#1 Logičari
    Mysql(索引)
    XTU-OJ 1178-Rectangle
    Qt配置OpenCV(保姆级教程)
    BUUCTF wireshark 1
    vue 日期控件 100天内的时间禁用不允许选择
    【小程序】-(小撒)
    论文解读(CDCL)《Cross-domain Contrastive Learning for Unsupervised Domain Adaptation》
  • 原文地址:https://blog.csdn.net/weixin_44144773/article/details/139800541