• mindspore-softmax进行鸢尾花多分类模型


    版本:mindspore1.3.0

    代码:

    import os
    # os.environ['DEVICE_ID'] = '6'
    import csv
    import numpy as np

    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import dataset
    from mindspore.train.callback import LossMonitor
    from mindspore.common.api import ms_function
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    with open('iris.data') as csv_file:
        data = list(csv.reader(csv_file, delimiter=','))

    label_map = {
        'Iris-setosa': 0,
        'Iris-versicolor': 1,
        'Iris-virginica':2,
    }

    X = np.array([[float(x) for x in s[:-1]] for s in data[:150]], np.float32)
     

    Y = np.array([[label_map[s[-1]]] for s in data[:150]], np.float32)
     

    train_idx = np.random.choice(150, 120, replace=False)
    test_idx = np.array(list(set(range(150)) - set(train_idx)))
    X_train, Y_train = X[train_idx], Y[train_idx]
    X_test, Y_test = X[test_idx], Y[test_idx]
    XY_train = list(zip(X_train, Y_train))
    ds_train = dataset.GeneratorDataset(XY_train, ['x', 'y'])

    ds_train = ds_train.shuffle(buffer_size=80).batch(32, drop_remainder=True)
    XY_test = list(zip(X_test, Y_test))
    ds_test = dataset.GeneratorDataset(XY_test, ['x', 'y'])
    ds_test = ds_test.batch(30)

    net = nn.Dense(4, 3)
    loss = nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    opt = nn.optim.Momentum(net.trainable_params(), learning_rate=0.05, momentum=0.9)

    model = ms.train.Model(net, loss, opt, metrics={'acc', 'loss'})
    model.train(15, ds_train, callbacks=[LossMonitor(per_print_times=ds_train.get_dataset_size())], dataset_sink_mode=False)
    metrics = model.eval(ds_test)
    print(metrics)

    1. 按照报错提示,是因为你的dataset对象给多个model使用了。

    2. 但是我们拿了你上面的脚本,先从脚本上看,没有发现ds_train / ds_eval给多个model使用的情况,另:在本地运行后,也没有报你上面的报错。

    故:你可以再试下你上条评论里的脚本,或者还有没有其他信息提供?我们再分析下。

  • 相关阅读:
    Linux之查看so/bin依赖(三十一)
    性能测试 Linux 环境下模拟延时和丢包实现
    说说HBase读、写流程
    [C国演义] 第十六章
    Web3行业人才需求激增,加密初创企业的薪资究竟如何?
    第四章 文件管理 十二、虚拟文件系统
    Kotlin协程分析(三)——理解协程上下文
    ARCGIS之成片区开发方案报备坐标txt格式批量导出工具(定制开发版)
    极限多标签学习之SwiftXML
    Android硬件服务访问(2):Driver
  • 原文地址:https://blog.csdn.net/weixin_45666880/article/details/126409739