• [Python Deep Learning] Python Full-Stack Series (34)


    Deep Learning

    Chapter 15: Data Preparation

    I. Data Preparation

    1. What is data preparation?
    • Data preparation means reading sample data from external sources (mainly files) and passing it to the neural network in a specified way (shuffled, batched) for training or testing
    • Data preparation involves three steps:
      • Step 1: define a custom Reader that generates training/prediction data
      • Step 2: define data-layer variables in the network configuration
      • Step 3: feed the data into the network for training/prediction
    2. Why is data preparation needed?
    • Reading data from files. A program cannot hold large amounts of data in memory, so data is usually stored in files and needs a dedicated read step
    • Fast batched reading. Deep learning sample sets are large and must be read quickly and efficiently (batch reading mode)
    • Shuffled reading. To improve the model's generalization ability, data sometimes needs to be read in random order (shuffle reading mode)
    3. Code
    import paddle
    
    
    # base reader
    def reader_creator(file_path):
        def reader():
            with open(file_path, "r") as f:  # open the file
                lines = f.readlines()  # read all lines
                for line in lines:
                    yield line.replace("\n", "")  # yield one sample at a time via the generator
    
        return reader
    
    
    reader = reader_creator("test.txt")  # sequential reader
    shuffle_reader = paddle.reader.shuffle(reader, 10)  # shuffled reader
    batch_reader = paddle.batch(shuffle_reader, 3)  # batched shuffled reader
    
    # for data in reader():  # iterate over the sequential reader
    # for data in shuffle_reader():  # iterate over the shuffled reader
    for data in batch_reader():  # iterate over the batched shuffled reader
        print(data, end="")
    
    """
    ['888888888888,8', '111111111111,1', '444444444444,4']['333333333333,3', '666666666666,6', '222222222222,2']['000000000000,0', '999999999999,9', '555555555555,5']['777777777777,7']
    """
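Conceptually, `paddle.reader.shuffle` and `paddle.batch` are just generator wrappers around a base reader. The sketch below is a simplified plain-Python reimplementation of the idea (not Paddle's actual code): shuffle fills a buffer of `buf_size` samples, shuffles it, and emits the samples; batch groups consecutive samples into lists.

```python
import random


def shuffle_reader(reader, buf_size):
    """Simplified analogue of paddle.reader.shuffle: buffer buf_size
    samples, shuffle the buffer, then yield its contents."""
    def new_reader():
        buf = []
        for sample in reader():
            buf.append(sample)
            if len(buf) >= buf_size:
                random.shuffle(buf)
                for s in buf:
                    yield s
                buf = []
        if buf:  # flush the partially filled buffer at the end
            random.shuffle(buf)
            for s in buf:
                yield s
    return new_reader


def batch_reader(reader, batch_size):
    """Simplified analogue of paddle.batch: group samples into lists
    of batch_size; the last batch may be smaller."""
    def new_reader():
        batch = []
        for sample in reader():
            batch.append(sample)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if batch:
            yield batch
    return new_reader


# Composed exactly like the paddle example above: shuffle, then batch
base = lambda: iter(range(10))
for b in batch_reader(shuffle_reader(base, 10), 3)():
    print(b)
```

This explains the output shown above: three full batches of 3 samples followed by one leftover batch of 1.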
    

    II. Model Saving and Loading

    1. Saving and loading an inference model
    • Save an inference model:
      • fluid.io.save_inference_model(dirname, feeded_var_names, target_vars, executor)
      • Parameters:
        • dirname (str): directory where the inference model is saved
        • feeded_var_names (list[str]): names of the variables that must be fed at prediction time
        • target_vars (list[Variable]): Variables holding the prediction results
        • executor (Executor): the executor that saves the inference model
    • Load an inference model:
      • fluid.io.load_inference_model(dirname, executor)
      • Parameters:
        • dirname (str): directory where the inference model was saved
        • executor (Executor): the Executor that runs the model
      • Return values:
        • Program: the Program used for prediction
        • feed_target_names (list of str): names of the variables in the prediction Program that take input data
        • fetch_targets (list of Variable): Variables holding the prediction results
    2. Saving and loading for incremental training
    • Save persistable variables:
      • fluid.io.save_persistables(executor, dirname, main_program=None)
      • Parameters:
        • executor (Executor): the executor that saves the variables
        • dirname (str): directory where the model is saved
        • main_program (Program|None): the Program whose variables should be saved. If None, default_main_program is used
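The split between the two save modes can be illustrated with a plain-Python analogy (hypothetical function names, not the fluid API): an inference model persists only what prediction needs — the names of the inputs to feed plus the parameters that produce the targets — whereas save_persistables would store every trainable variable so training can resume.

```python
import os
import pickle
import tempfile


def save_inference_bundle(dirname, feeded_var_names, params):
    """Hypothetical analogy of fluid.io.save_inference_model:
    persist only what prediction needs (input names + parameters)."""
    os.makedirs(dirname, exist_ok=True)
    with open(os.path.join(dirname, "inference.pkl"), "wb") as f:
        pickle.dump({"feed_names": feeded_var_names, "params": params}, f)


def load_inference_bundle(dirname):
    """Hypothetical analogy of fluid.io.load_inference_model:
    returns the feed names and the parameters used for prediction."""
    with open(os.path.join(dirname, "inference.pkl"), "rb") as f:
        bundle = pickle.load(f)
    return bundle["feed_names"], bundle["params"]


model_dir = tempfile.mkdtemp()
save_inference_bundle(model_dir, ["x"], {"w": [0.5], "b": [0.1]})
feed_names, params = load_inference_bundle(model_dir)
print(feed_names, params)  # ['x'] {'w': [0.5], 'b': [0.1]}
```

This mirrors the API shape above: saving takes a directory plus the feed names and targets; loading returns what you need to build the feed dict and fetch results.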
    3. fluid API Structure Diagram

    (figure: fluid API structure diagram)

    Chapter 16: Comprehensive Case Study: Boston Housing Price Prediction

    Task Overview

    1. Dataset and task
    • Dataset:
      • Size: 506 samples
      • Features: 13 (see figure below)
      • Label: median house price
    • Task: predict the median house price from the sample data (a regression problem)
      (figure: descriptions of the 13 features)
    2. Approach

    (figure: approach diagram)

    3. Code
    # Boston housing price prediction (multivariate regression)
    """
    Dataset: 506 housing samples, each with 13 features and 1 label
    """
    import os
    
    import paddle
    import paddle.fluid as fluid
    import numpy as np
    import matplotlib.pyplot as plt
    
    # Step 1: data preparation
    # shuffle buffer size
    BUF_SIZE = 500
    # batch size
    BATCH_SIZE = 20
    
    random_reader = paddle.reader.shuffle(paddle.dataset.uci_housing.train(),  # training-set reader
                                          buf_size=BUF_SIZE)
    train_reader = paddle.batch(random_reader, batch_size=BATCH_SIZE)  # batched reader
    # # print the data
    # train_data = paddle.dataset.uci_housing.train()
    # for sample in train_data():
    #     print(sample)
    """
    # 13 features
    # label: median house price
    (array([ 0.23814999, -0.11363636,  0.25525005, -0.06916996,  0.28457807,
           -0.17927465,  0.2824418 , -0.1902575 ,  0.62828665,  0.49191383,
            0.18558153,  0.10143217,  0.19638346]), array([8.3]))
    """
    # Step 2: model definition
    x = fluid.layers.data(name="x", shape=[13], dtype="float32")
    y = fluid.layers.data(name="y", shape=[1], dtype="float32")
    # fully connected layer
    y_predict = fluid.layers.fc(input=x,  # input
                                size=1,  # number of output values
                                act=None)  # activation function
    # loss function
    cost = fluid.layers.square_error_cost(input=y_predict,  # predicted value
                                          label=y)  # ground truth
    avg_cost = fluid.layers.mean(cost)  # mean squared error
    # optimizer
    optimizer = fluid.optimizer.SGD(learning_rate=0.001)
    optimizer.minimize(avg_cost)  # objective to minimize
    # Step 3: training and saving the model
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    # feeder: converts the input data into the tensor format the model expects
    feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
    iter_count = 0
    iters = []
    train_costs = []
    EPOCH_NUM = 120
    model_save_dir = "model/uci_housing"  # model save path
    
    for pass_id in range(EPOCH_NUM):
        train_cost = 0
        i = 0
        for data in train_reader():
            i += 1
            train_cost = exe.run(program=fluid.default_main_program(),
                                 feed=feeder.feed(data),
                                 fetch_list=[avg_cost])
            if i % 20 == 0:
                print("pass_id:%d, cost:%f" % (pass_id, train_cost[0][0]))
            iter_count += BATCH_SIZE
            iters.append(iter_count)  # record the number of samples seen
            train_costs.append(train_cost[0][0])  # record the loss value
    
    # save the model
    if not os.path.exists(model_save_dir):
        os.makedirs(model_save_dir)
    fluid.io.save_inference_model(model_save_dir,  # model save path
                                  ["x"],  # variables to feed at prediction time
                                  [y_predict],  # where to fetch the prediction results
                                  exe)  # executor
    
    # visualize the training process
    plt.figure("Training Cost")
    plt.title("Training Cost", fontsize=24)
    plt.xlabel("iter", fontsize=14)
    plt.ylabel("cost", fontsize=14)
    plt.plot(iters, train_costs, color="red", label="Training Cost")
    plt.grid()
    plt.savefig("train.png")
    # Step 4: loading the model and predicting
    infer_exe = fluid.Executor(place)
    infer_result = []  # predicted values
    ground_truths = []  # ground-truth values
    
    # load the model
    infer_program, feed_target_names, fetch_targets = \
        fluid.io.load_inference_model(model_save_dir,  # model save path
                                      infer_exe)  # executor to load onto
    
    # test-set reader
    infer_reader = paddle.batch(paddle.dataset.uci_housing.test(),  # read the test set
                                batch_size=200)
    test_data = next(infer_reader())  # fetch one batch
    test_x = np.array([data[0] for data in test_data]).astype("float32")
    test_y = np.array([data[1] for data in test_data]).astype("float32")
    # build the feed dict
    x_name = feed_target_names[0]  # name of the input variable
    results = infer_exe.run(infer_program,  # program to run for prediction
                            feed={x_name: test_x},  # input
                            fetch_list=fetch_targets)  # fetch the predictions
    # predicted values
    for idx, val in enumerate(results[0]):
        print("%d: %f" % (idx, val))
        infer_result.append(val)
    
    # ground-truth values
    for idx, val in enumerate(test_y):
        print("%d: %f" % (idx, val))
        ground_truths.append(val)
    
    # visualize the predictions
    plt.figure("infer")
    plt.title("infer", fontsize=24)
    plt.xlabel("ground truth", fontsize=14)
    plt.ylabel("infer result", fontsize=14)
    x = np.arange(1, 30)
    y = x
    plt.plot(x, y)  # plot the y = x reference line
    plt.scatter(ground_truths, infer_result, color="green", label="infer")
    plt.grid()
    plt.legend()
    plt.savefig("predict.png")
    plt.show()
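For reference, the same pipeline — a linear layer, squared-error cost averaged over the batch, and SGD with learning rate 0.001 and batch size 20 — can be mirrored in plain NumPy, which makes the gradient math explicit. This sketch uses synthetic data in place of the real uci_housing dataset:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for the housing data: 13 features, linear target
true_w = rng.normal(size=(13, 1))
X = rng.normal(size=(200, 13)).astype("float32")
y = X @ true_w + 0.1 * rng.normal(size=(200, 1))

# Parameters of the fc layer: y_pred = X @ w + b
w = np.zeros((13, 1))
b = np.zeros(1)

lr = 0.001  # same learning rate as the fluid example
losses = []
for epoch in range(120):                  # EPOCH_NUM = 120
    for i in range(0, len(X), 20):        # BATCH_SIZE = 20
        xb, yb = X[i:i + 20], y[i:i + 20]
        y_pred = xb @ w + b
        err = y_pred - yb
        loss = np.mean(err ** 2)          # square_error_cost + mean
        # gradients of the mean squared error
        grad_w = 2 * xb.T @ err / len(xb)
        grad_b = 2 * err.mean()
        w -= lr * grad_w                  # SGD update
        b -= lr * grad_b
        losses.append(loss)

print("first batch loss:", losses[0], "last batch loss:", losses[-1])
```

As in the training log below, the per-batch loss is noisy but trends steadily downward.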
    
    4. Execution Results
    pass_id:0, cost:656.089905
    pass_id:1, cost:451.363678
    pass_id:2, cost:584.316589
    pass_id:3, cost:477.108948
    pass_id:4, cost:286.599640
    pass_id:5, cost:445.707367
    pass_id:6, cost:446.433838
    pass_id:7, cost:335.848511
    pass_id:8, cost:273.062225
    pass_id:9, cost:255.784912
    pass_id:10, cost:278.432373
    pass_id:11, cost:276.121887
    pass_id:12, cost:175.726196
    pass_id:13, cost:170.238754
    pass_id:14, cost:170.852570
    pass_id:15, cost:223.544922
    pass_id:16, cost:166.904495
    pass_id:17, cost:321.751526
    pass_id:18, cost:280.356567
    pass_id:19, cost:111.091576
    pass_id:20, cost:124.681442
    pass_id:21, cost:64.695580
    pass_id:22, cost:129.477448
    pass_id:23, cost:133.440948
    pass_id:24, cost:130.348145
    pass_id:25, cost:102.667458
    pass_id:26, cost:64.281265
    pass_id:27, cost:142.222763
    pass_id:28, cost:26.178593
    pass_id:29, cost:220.263596
    pass_id:30, cost:169.756500
    pass_id:31, cost:119.223656
    pass_id:32, cost:87.624367
    pass_id:33, cost:59.109009
    pass_id:34, cost:164.397720
    pass_id:35, cost:98.710800
    pass_id:36, cost:117.974304
    pass_id:37, cost:64.506653
    pass_id:38, cost:104.113625
    pass_id:39, cost:78.288124
    pass_id:40, cost:103.716820
    pass_id:41, cost:95.082603
    pass_id:42, cost:34.526741
    pass_id:43, cost:99.486519
    pass_id:44, cost:122.406921
    pass_id:45, cost:176.348862
    pass_id:46, cost:56.122032
    pass_id:47, cost:48.959282
    pass_id:48, cost:114.838989
    pass_id:49, cost:173.656082
    pass_id:50, cost:96.353210
    pass_id:51, cost:129.643478
    pass_id:52, cost:78.118484
    pass_id:53, cost:56.693672
    pass_id:54, cost:67.857742
    pass_id:55, cost:15.136653
    pass_id:56, cost:87.636497
    pass_id:57, cost:45.029011
    pass_id:58, cost:108.201218
    pass_id:59, cost:32.179466
    pass_id:60, cost:34.872448
    pass_id:61, cost:94.557373
    pass_id:62, cost:127.176132
    pass_id:63, cost:81.021133
    pass_id:64, cost:27.862711
    pass_id:65, cost:75.477615
    pass_id:66, cost:119.252541
    pass_id:67, cost:93.257736
    pass_id:68, cost:25.911819
    pass_id:69, cost:17.109428
    pass_id:70, cost:35.836407
    pass_id:71, cost:69.057404
    pass_id:72, cost:112.613510
    pass_id:73, cost:68.981125
    pass_id:74, cost:49.957832
    pass_id:75, cost:20.481647
    pass_id:76, cost:59.729126
    pass_id:77, cost:45.460415
    pass_id:78, cost:22.951813
    pass_id:79, cost:40.394081
    pass_id:80, cost:37.409126
    pass_id:81, cost:41.443184
    pass_id:82, cost:70.590271
    pass_id:83, cost:54.799217
    pass_id:84, cost:41.712090
    pass_id:85, cost:79.634201
    pass_id:86, cost:103.184982
    pass_id:87, cost:15.930639
    pass_id:88, cost:97.250771
    pass_id:89, cost:43.428303
    pass_id:90, cost:104.876076
    pass_id:91, cost:71.580521
    pass_id:92, cost:38.239330
    pass_id:93, cost:16.533834
    pass_id:94, cost:43.827812
    pass_id:95, cost:16.911013
    pass_id:96, cost:66.245995
    pass_id:97, cost:45.150234
    pass_id:98, cost:13.511981
    pass_id:99, cost:41.205372
    pass_id:100, cost:17.888485
    pass_id:101, cost:51.672241
    pass_id:102, cost:54.815704
    pass_id:103, cost:20.194555
    pass_id:104, cost:110.166306
    pass_id:105, cost:53.912636
    pass_id:106, cost:26.374447
    pass_id:107, cost:14.297429
    pass_id:108, cost:23.325668
    pass_id:109, cost:60.575584
    pass_id:110, cost:46.281517
    pass_id:111, cost:162.359894
    pass_id:112, cost:46.856792
    pass_id:113, cost:101.333237
    pass_id:114, cost:45.367104
    pass_id:115, cost:36.769276
    pass_id:116, cost:31.477345
    pass_id:117, cost:59.371132
    pass_id:118, cost:19.479343
    pass_id:119, cost:45.612984
    0: 14.273354
    1: 14.844604
    2: 13.930457
    3: 16.414011
    4: 14.629221
    5: 15.708012
    6: 15.217474
    7: 14.663005
    8: 11.310017
    9: 14.527328
    10: 10.711575
    11: 13.088373
    12: 13.984264
    13: 13.269247
    14: 13.599314
    15: 14.719706
    16: 16.353481
    17: 16.084949
    18: 16.393032
    19: 14.104583
    20: 14.957327
    21: 13.355427
    22: 15.661732
    23: 15.234718
    24: 14.793789
    25: 14.038326
    26: 15.540333
    27: 15.426285
    28: 16.688440
    29: 15.434376
    30: 15.232512
    31: 14.383001
    32: 14.589976
    33: 12.967784
    34: 12.295956
    35: 15.128961
    36: 15.319921
    37: 16.048567
    38: 16.329090
    39: 16.142662
    40: 14.333231
    41: 13.827244
    42: 15.941467
    43: 16.392628
    44: 16.204958
    45: 15.768801
    46: 14.903126
    47: 16.429163
    48: 16.509726
    49: 17.171146
    50: 14.705102
    51: 15.027884
    52: 14.286346
    53: 14.653322
    54: 16.267881
    55: 16.920368
    56: 16.329792
    57: 17.043941
    58: 17.219616
    59: 17.729343
    60: 17.799040
    61: 17.409719
    62: 14.858363
    63: 15.860466
    64: 16.875864
    65: 17.623283
    66: 17.241474
    67: 17.786171
    68: 17.867687
    69: 18.561293
    70: 15.950287
    71: 15.370593
    72: 16.801212
    73: 14.720805
    74: 16.501842
    75: 17.383152
    76: 18.622833
    77: 19.182955
    78: 19.481457
    79: 18.825390
    80: 18.174511
    81: 18.768740
    82: 17.449657
    83: 18.290375
    84: 17.046276
    85: 15.830663
    86: 14.686304
    87: 17.283913
    88: 18.288227
    89: 22.529139
    90: 22.696047
    91: 22.225986
    92: 20.664625
    93: 21.986166
    94: 22.403265
    95: 21.603868
    96: 21.927729
    97: 23.351410
    98: 22.955769
    99: 23.764650
    100: 23.540255
    101: 22.982683
    0: 8.500000
    1: 5.000000
    2: 11.900000
    3: 27.900000
    4: 17.200001
    5: 27.500000
    6: 15.000000
    7: 17.200001
    8: 17.900000
    9: 16.299999
    10: 7.000000
    11: 7.200000
    12: 7.500000
    13: 10.400000
    14: 8.800000
    15: 8.400000
    16: 16.700001
    17: 14.200000
    18: 20.799999
    19: 13.400000
    20: 11.700000
    21: 8.300000
    22: 10.200000
    23: 10.900000
    24: 11.000000
    25: 9.500000
    26: 14.500000
    27: 14.100000
    28: 16.100000
    29: 14.300000
    30: 11.700000
    31: 13.400000
    32: 9.600000
    33: 8.700000
    34: 8.400000
    35: 12.800000
    36: 10.500000
    37: 17.100000
    38: 18.400000
    39: 15.400000
    40: 10.800000
    41: 11.800000
    42: 14.900000
    43: 12.600000
    44: 14.100000
    45: 13.000000
    46: 13.400000
    47: 15.200000
    48: 16.100000
    49: 17.799999
    50: 14.900000
    51: 14.100000
    52: 12.700000
    53: 13.500000
    54: 14.900000
    55: 20.000000
    56: 16.400000
    57: 17.700001
    58: 19.500000
    59: 20.200001
    60: 21.400000
    61: 19.900000
    62: 19.000000
    63: 19.100000
    64: 19.100000
    65: 20.100000
    66: 19.900000
    67: 19.600000
    68: 23.200001
    69: 29.799999
    70: 13.800000
    71: 13.300000
    72: 16.700001
    73: 12.000000
    74: 14.600000
    75: 21.400000
    76: 23.000000
    77: 23.700001
    78: 25.000000
    79: 21.799999
    80: 20.600000
    81: 21.200001
    82: 19.100000
    83: 20.600000
    84: 15.200000
    85: 7.000000
    86: 8.100000
    87: 13.600000
    88: 20.100000
    89: 21.799999
    90: 24.500000
    91: 23.100000
    92: 19.700001
    93: 18.299999
    94: 21.200001
    95: 17.500000
    96: 16.799999
    97: 22.400000
    98: 20.600000
    99: 23.900000
    100: 22.000000
    101: 11.900000
    


  • Original article: https://blog.csdn.net/sgsgkxkx/article/details/126648226