• 【深度学习】实验05 构造神经网络示例


    构造神经网络

    神经网络是一种仿生学原理的机器学习算法,灵感来源于人脑的神经系统。它由多个神经元(或称为节点)组成,这些神经元通过连接权重形成复杂的网络结构,用来学习和提取输入数据的特征,并用于分类、回归、聚类等任务。
    注明:该代码用来训练一个神经网络,网络拟合y = x^2-0.5+noise,该神经网络的结构是输入层为一个神经元,隐藏层为十个神经元,输出层为一个神经元

    1. 导入相关库

    # 导入相关库
    import tensorflow as tf  # 用来构造神经网络
    import numpy as np  # 用来构造数据结构和处理数据模块
    
    • 1
    • 2
    • 3

    这段代码使用了两个 Python 模块:

    1. tensorflow:这是 Google 开源的机器学习框架,用来构造神经网络和训练模型。

    2. numpy:这是 Python 中用于矩阵/数组运算的基础库,用来构造数据结构和处理数据。

    具体来说:

    • import tensorflow as tf 导入 TensorFlow 库并给它起个别名 tf
    • import numpy as np 导入 NumPy 库并给它起个别名 np

    2. 定义一个层

    # 定义一个层
    def add_layer(inputs, in_size, out_size, activation_function=None):
        # 定义一个层,其中inputs为输入,in_size为上一层神经元数,out_size为该层神经元数
        # activation_function为激励函数
        Weights = tf.Variable(tf.random_normal([in_size, out_size]))
        # 初始权重随机生成比较好,in_size,out_size为该权重维度
        biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
        # 偏置
        Wx_plus_b = tf.matmul(inputs, Weights) + biases
        # matmul为矩阵里的函数相乘
        if activation_function is None:
            outputs = Wx_plus_b  # 如果激活函数为空,则不激活,保持数据
        else:
            outputs = activation_function(Wx_plus_b)
            # 如果激活函数不为空,则激活,并且返回激活后的值
        return outputs  # 返回激活后的值
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    这段代码定义了一个函数 add_layer,用于添加一层神经网络。

    代码中的参数解释如下:

    • inputs:该层的输入。
    • in_size:该层的输入维度,即上一层神经元数。
    • out_size:该层的输出维度,即该层神经元数。
    • activation_function:该层使用的激活函数,可以为空。

    函数内部逻辑:

    • Weights = tf.Variable(tf.random_normal([in_size, out_size])):定义该层的权重,使用随机生成的正态分布数据,维度为 [in_size, out_size]
    • biases = tf.Variable(tf.zeros([1, out_size]) + 0.1):定义该层的偏置,使用全0矩阵,维度为 [1, out_size],并且加上0.1,以避免点评中出现0的情况。
    • Wx_plus_b = tf.matmul(inputs, Weights) + biases:使用矩阵乘法计算该层的输出,即加权和加上偏置。
    • if activation_function is None::如果激活函数为空,则直接将加权和加上偏置的结果作为该层的输出。
    • else::否则,对加权和加上偏置的结果进行激活函数的处理,并将处理结果作为该层的输出。
    • return outputs:返回该层的输出。

    总的来说,该函数的作用是创建一个神经网络层,将输入经过加权和加上偏置的运算,并使用激活函数得到输出。

    3. 构造数据集

    # 构造一些样本,用来训练神经网络
    x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
    # 值为(-1,1)之间的数,有300个
    noise = np.random.normal(0, 0.05, x_data.shape)
    x_data
    
    • 1
    • 2
    • 3
    • 4
    • 5
       array([[-1.        ],
              [-0.99331104],
              [-0.98662207],
              [-0.97993311],
              [-0.97324415],
              [-0.96655518],
              [-0.95986622],
              [-0.95317726],
              [-0.94648829],
              [-0.93979933],
              [-0.93311037],
              [-0.9264214 ],
              [-0.91973244],
              [-0.91304348],
              [-0.90635452],
              [-0.89966555],
              [-0.89297659],
              [-0.88628763],
              [-0.87959866],
              [-0.8729097 ],
              [-0.86622074],
              [-0.85953177],
              [-0.85284281],
              [-0.84615385],
              [-0.83946488],
              [-0.83277592],
              [-0.82608696],
              [-0.81939799],
              [-0.81270903],
              [-0.80602007],
              [-0.7993311 ],
              [-0.79264214],
              [-0.78595318],
              [-0.77926421],
              [-0.77257525],
              [-0.76588629],
              [-0.75919732],
              [-0.75250836],
              [-0.7458194 ],
              [-0.73913043],
              [-0.73244147],
              [-0.72575251],
              [-0.71906355],
              [-0.71237458],
              [-0.70568562],
              [-0.69899666],
              [-0.69230769],
              [-0.68561873],
              [-0.67892977],
              [-0.6722408 ],
              [-0.66555184],
              [-0.65886288],
              [-0.65217391],
              [-0.64548495],
              [-0.63879599],
              [-0.63210702],
              [-0.62541806],
              [-0.6187291 ],
              [-0.61204013],
              [-0.60535117],
              [-0.59866221],
              [-0.59197324],
              [-0.58528428],
              [-0.57859532],
              [-0.57190635],
              [-0.56521739],
              [-0.55852843],
              [-0.55183946],
              [-0.5451505 ],
              [-0.53846154],
              [-0.53177258],
              [-0.52508361],
              [-0.51839465],
              [-0.51170569],
              [-0.50501672],
              [-0.49832776],
              [-0.4916388 ],
              [-0.48494983],
              [-0.47826087],
              [-0.47157191],
              [-0.46488294],
              [-0.45819398],
              [-0.45150502],
              [-0.44481605],
              [-0.43812709],
              [-0.43143813],
              [-0.42474916],
              [-0.4180602 ],
              [-0.41137124],
              [-0.40468227],
              [-0.39799331],
              [-0.39130435],
              [-0.38461538],
              [-0.37792642],
              [-0.37123746],
              [-0.36454849],
              [-0.35785953],
              [-0.35117057],
              [-0.34448161],
              [-0.33779264],
              [-0.33110368],
              [-0.32441472],
              [-0.31772575],
              [-0.31103679],
              [-0.30434783],
              [-0.29765886],
              [-0.2909699 ],
              [-0.28428094],
              [-0.27759197],
              [-0.27090301],
              [-0.26421405],
              [-0.25752508],
              [-0.25083612],
              [-0.24414716],
              [-0.23745819],
              [-0.23076923],
              [-0.22408027],
              [-0.2173913 ],
              [-0.21070234],
              [-0.20401338],
              [-0.19732441],
              [-0.19063545],
              [-0.18394649],
              [-0.17725753],
              [-0.17056856],
              [-0.1638796 ],
              [-0.15719064],
              [-0.15050167],
              [-0.14381271],
              [-0.13712375],
              [-0.13043478],
              [-0.12374582],
              [-0.11705686],
              [-0.11036789],
              [-0.10367893],
              [-0.09698997],
              [-0.090301  ],
              [-0.08361204],
              [-0.07692308],
              [-0.07023411],
              [-0.06354515],
              [-0.05685619],
              [-0.05016722],
              [-0.04347826],
              [-0.0367893 ],
              [-0.03010033],
              [-0.02341137],
              [-0.01672241],
              [-0.01003344],
              [-0.00334448],
              [ 0.00334448],
              [ 0.01003344],
              [ 0.01672241],
              [ 0.02341137],
              [ 0.03010033],
              [ 0.0367893 ],
              [ 0.04347826],
              [ 0.05016722],
              [ 0.05685619],
              [ 0.06354515],
              [ 0.07023411],
              [ 0.07692308],
              [ 0.08361204],
              [ 0.090301  ],
              [ 0.09698997],
              [ 0.10367893],
              [ 0.11036789],
              [ 0.11705686],
              [ 0.12374582],
              [ 0.13043478],
              [ 0.13712375],
              [ 0.14381271],
              [ 0.15050167],
              [ 0.15719064],
              [ 0.1638796 ],
              [ 0.17056856],
              [ 0.17725753],
              [ 0.18394649],
              [ 0.19063545],
              [ 0.19732441],
              [ 0.20401338],
              [ 0.21070234],
              [ 0.2173913 ],
              [ 0.22408027],
              [ 0.23076923],
              [ 0.23745819],
              [ 0.24414716],
              [ 0.25083612],
              [ 0.25752508],
              [ 0.26421405],
              [ 0.27090301],
              [ 0.27759197],
              [ 0.28428094],
              [ 0.2909699 ],
              [ 0.29765886],
              [ 0.30434783],
              [ 0.31103679],
              [ 0.31772575],
              [ 0.32441472],
              [ 0.33110368],
              [ 0.33779264],
              [ 0.34448161],
              [ 0.35117057],
              [ 0.35785953],
              [ 0.36454849],
              [ 0.37123746],
              [ 0.37792642],
              [ 0.38461538],
              [ 0.39130435],
              [ 0.39799331],
              [ 0.40468227],
              [ 0.41137124],
              [ 0.4180602 ],
              [ 0.42474916],
              [ 0.43143813],
              [ 0.43812709],
              [ 0.44481605],
              [ 0.45150502],
              [ 0.45819398],
              [ 0.46488294],
              [ 0.47157191],
              [ 0.47826087],
              [ 0.48494983],
              [ 0.4916388 ],
              [ 0.49832776],
              [ 0.50501672],
              [ 0.51170569],
              [ 0.51839465],
              [ 0.52508361],
              [ 0.53177258],
              [ 0.53846154],
              [ 0.5451505 ],
              [ 0.55183946],
              [ 0.55852843],
              [ 0.56521739],
              [ 0.57190635],
              [ 0.57859532],
              [ 0.58528428],
              [ 0.59197324],
              [ 0.59866221],
              [ 0.60535117],
              [ 0.61204013],
              [ 0.6187291 ],
              [ 0.62541806],
              [ 0.63210702],
              [ 0.63879599],
              [ 0.64548495],
              [ 0.65217391],
              [ 0.65886288],
              [ 0.66555184],
              [ 0.6722408 ],
              [ 0.67892977],
              [ 0.68561873],
              [ 0.69230769],
              [ 0.69899666],
              [ 0.70568562],
              [ 0.71237458],
              [ 0.71906355],
              [ 0.72575251],
              [ 0.73244147],
              [ 0.73913043],
              [ 0.7458194 ],
              [ 0.75250836],
              [ 0.75919732],
              [ 0.76588629],
              [ 0.77257525],
              [ 0.77926421],
              [ 0.78595318],
              [ 0.79264214],
              [ 0.7993311 ],
              [ 0.80602007],
              [ 0.81270903],
              [ 0.81939799],
              [ 0.82608696],
              [ 0.83277592],
              [ 0.83946488],
              [ 0.84615385],
              [ 0.85284281],
              [ 0.85953177],
              [ 0.86622074],
              [ 0.8729097 ],
              [ 0.87959866],
              [ 0.88628763],
              [ 0.89297659],
              [ 0.89966555],
              [ 0.90635452],
              [ 0.91304348],
              [ 0.91973244],
              [ 0.9264214 ],
              [ 0.93311037],
              [ 0.93979933],
              [ 0.94648829],
              [ 0.95317726],
              [ 0.95986622],
              [ 0.96655518],
              [ 0.97324415],
              [ 0.97993311],
              [ 0.98662207],
              [ 0.99331104],
              [ 1.        ]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203
    • 204
    • 205
    • 206
    • 207
    • 208
    • 209
    • 210
    • 211
    • 212
    • 213
    • 214
    • 215
    • 216
    • 217
    • 218
    • 219
    • 220
    • 221
    • 222
    • 223
    • 224
    • 225
    • 226
    • 227
    • 228
    • 229
    • 230
    • 231
    • 232
    • 233
    • 234
    • 235
    • 236
    • 237
    • 238
    • 239
    • 240
    • 241
    • 242
    • 243
    • 244
    • 245
    • 246
    • 247
    • 248
    • 249
    • 250
    • 251
    • 252
    • 253
    • 254
    • 255
    • 256
    • 257
    • 258
    • 259
    • 260
    • 261
    • 262
    • 263
    • 264
    • 265
    • 266
    • 267
    • 268
    • 269
    • 270
    • 271
    • 272
    • 273
    • 274
    • 275
    • 276
    • 277
    • 278
    • 279
    • 280
    • 281
    • 282
    • 283
    • 284
    • 285
    • 286
    • 287
    • 288
    • 289
    • 290
    • 291
    • 292
    • 293
    • 294
    • 295
    • 296
    • 297
    • 298
    • 299
    • 300

    这段代码使用 numpy 库创建了一个一维的数组 x_data

    代码中的参数解释如下:

    • -1:数组中数的最小值。
    • 1:数组中数的最大值。
    • 300:数组中数的个数。
    • [:, np.newaxis]:对数组进行转置,转换成二维数组。

    函数内部逻辑:

    • np.linspace(-1, 1, 300):返回一个数值范围在 -1 到 1 之间,总共有 300 个数的等差数列。即生成一个 ndarray 数组,包含 300 个数,分布在 -1 到 1 之间。
    • [:, np.newaxis]:将一维数组转化成列向量,即让数组的 shape 从 (300,) 变成 (300, 1)

    最终生成的 x_data 是一个二维数组,第一维度为 300,第二维度为 1,表示由 300 个样本组成,每个样本只有一个特征。

    # 加入噪声会更贴近真实情况,噪声的值为(0,0.05)之间,结构为x_data一样
    y_data = np.square(x_data) - 0.5 + noise
    # y的结构
    y_data
    
    • 1
    • 2
    • 3
    • 4
       array([[ 0.59535036],
              [ 0.46017998],
              [ 0.47144478],
              [ 0.45083795],
              [ 0.58438217],
              [ 0.38570118],
              [ 0.43550029],
              [ 0.40597571],
              [ 0.3357524 ],
              [ 0.35784864],
              [ 0.34530231],
              [ 0.32509701],
              [ 0.25554733],
              [ 0.32300801],
              [ 0.2299959 ],
              [ 0.35472568],
              [ 0.31227671],
              [ 0.30385068],
              [ 0.29413844],
              [ 0.18437787],
              [ 0.28132819],
              [ 0.25605309],
              [ 0.23126361],
              [ 0.23492797],
              [ 0.18381621],
              [ 0.10392937],
              [ 0.13415913],
              [ 0.14043649],
              [ 0.11756826],
              [ 0.12142749],
              [ 0.12400694],
              [ 0.08926307],
              [ 0.15581832],
              [ 0.16541106],
              [-0.02582895],
              [ 0.05924725],
              [-0.04037454],
              [ 0.03799003],
              [ 0.09030832],
              [ 0.05984324],
              [-0.06569464],
              [ 0.07973773],
              [ 0.04297837],
              [ 0.05169557],
              [-0.00096191],
              [-0.02049573],
              [-0.03125322],
              [-0.04545588],
              [-0.02168901],
              [ 0.01657517],
              [-0.04315181],
              [-0.09123519],
              [-0.03292835],
              [-0.1110189 ],
              [-0.08212792],
              [-0.10089535],
              [-0.17406672],
              [-0.10380731],
              [-0.10774072],
              [-0.21283138],
              [-0.09788435],
              [-0.10196452],
              [-0.16439081],
              [-0.15431978],
              [-0.17778307],
              [-0.18428537],
              [-0.17874028],
              [-0.10490738],
              [-0.25076832],
              [-0.16078044],
              [-0.21572183],
              [-0.15624353],
              [-0.19591988],
              [-0.31560742],
              [-0.29593726],
              [-0.26686787],
              [-0.2999804 ],
              [-0.30631065],
              [-0.35305224],
              [-0.31295125],
              [-0.22996255],
              [-0.22837061],
              [-0.27266253],
              [-0.31290802],
              [-0.37188479],
              [-0.20765034],
              [-0.33860431],
              [-0.31135236],
              [-0.25249981],
              [-0.26041048],
              [-0.31486205],
              [-0.30253306],
              [-0.41624795],
              [-0.40053837],
              [-0.29939676],
              [-0.32615377],
              [-0.37377787],
              [-0.32222027],
              [-0.3158838 ],
              [-0.43880087],
              [-0.37510637],
              [-0.46702321],
              [-0.27058091],
              [-0.52885151],
              [-0.4061462 ],
              [-0.4486374 ],
              [-0.37819628],
              [-0.34701947],
              [-0.32454364],
              [-0.3901839 ],
              [-0.43293107],
              [-0.47881173],
              [-0.45280819],
              [-0.49676541],
              [-0.48955669],
              [-0.45898691],
              [-0.37473462],
              [-0.43801531],
              [-0.44793655],
              [-0.57343047],
              [-0.45262969],
              [-0.40719677],
              [-0.45423461],
              [-0.45053051],
              [-0.51046881],
              [-0.41584096],
              [-0.53328545],
              [-0.44766406],
              [-0.50158463],
              [-0.42676031],
              [-0.50552613],
              [-0.36832989],
              [-0.48699296],
              [-0.41614151],
              [-0.6175621 ],
              [-0.48304532],
              [-0.46115021],
              [-0.40948908],
              [-0.42017024],
              [-0.50411757],
              [-0.44530626],
              [-0.46895275],
              [-0.52127771],
              [-0.50064585],
              [-0.42210169],
              [-0.58582837],
              [-0.52049198],
              [-0.45332091],
              [-0.53465815],
              [-0.5385712 ],
              [-0.5654201 ],
              [-0.54471377],
              [-0.48109194],
              [-0.44565732],
              [-0.48112022],
              [-0.46471786],
              [-0.5452149 ],
              [-0.52115601],
              [-0.50234928],
              [-0.54885558],
              [-0.5279981 ],
              [-0.53893795],
              [-0.44286416],
              [-0.45371406],
              [-0.44633111],
              [-0.57535678],
              [-0.62918947],
              [-0.41877124],
              [-0.56263956],
              [-0.51201705],
              [-0.35016007],
              [-0.49188897],
              [-0.55766056],
              [-0.38963378],
              [-0.5038024 ],
              [-0.51949984],
              [-0.45229896],
              [-0.49193029],
              [-0.53472883],
              [-0.48957523],
              [-0.35561181],
              [-0.4622668 ],
              [-0.39177781],
              [-0.43448445],
              [-0.49854629],
              [-0.49843105],
              [-0.47704375],
              [-0.36618194],
              [-0.45177012],
              [-0.41497222],
              [-0.42152064],
              [-0.48996608],
              [-0.43010878],
              [-0.42599962],
              [-0.2841197 ],
              [-0.38992082],
              [-0.43802592],
              [-0.42448799],
              [-0.29514676],
              [-0.37154091],
              [-0.25426219],
              [-0.44610678],
              [-0.37120566],
              [-0.3531599 ],
              [-0.34606119],
              [-0.29637877],
              [-0.3693284 ],
              [-0.36651142],
              [-0.30025118],
              [-0.31443603],
              [-0.40824064],
              [-0.31734053],
              [-0.40807378],
              [-0.33792031],
              [-0.22414921],
              [-0.37707072],
              [-0.26776417],
              [-0.29152204],
              [-0.34066934],
              [-0.19037511],
              [-0.23552614],
              [-0.2144995 ],
              [-0.27628531],
              [-0.27329725],
              [-0.23910513],
              [-0.30009859],
              [-0.30192088],
              [-0.16403744],
              [-0.32546893],
              [-0.25686912],
              [-0.12515146],
              [-0.21483097],
              [-0.12779443],
              [-0.28748063],
              [-0.23782354],
              [-0.16024807],
              [-0.19062672],
              [-0.15066097],
              [-0.19043274],
              [-0.16583211],
              [-0.11201314],
              [-0.05612149],
              [-0.00847256],
              [-0.1429705 ],
              [-0.09595988],
              [-0.09583441],
              [-0.01372838],
              [-0.04818834],
              [-0.11840653],
              [ 0.02184166],
              [-0.07153294],
              [-0.11556547],
              [-0.04731049],
              [-0.10774914],
              [-0.014642  ],
              [-0.01470962],
              [-0.03259555],
              [-0.04194347],
              [ 0.08987345],
              [-0.02027899],
              [ 0.02418433],
              [ 0.04298611],
              [ 0.04130101],
              [ 0.18010436],
              [ 0.15480307],
              [ 0.02719993],
              [ 0.11508363],
              [ 0.04309794],
              [ 0.14060578],
              [ 0.09377926],
              [ 0.13887198],
              [ 0.16148276],
              [ 0.11398259],
              [ 0.27887578],
              [ 0.22775177],
              [ 0.20749998],
              [ 0.22107721],
              [ 0.20854961],
              [ 0.25411644],
              [ 0.26561906],
              [ 0.27540788],
              [ 0.26946028],
              [ 0.2390275 ],
              [ 0.26051795],
              [ 0.34424064],
              [ 0.3240088 ],
              [ 0.38040554],
              [ 0.35717078],
              [ 0.31357911],
              [ 0.43825368],
              [ 0.35709739],
              [ 0.48101049],
              [ 0.36024364],
              [ 0.43253108],
              [ 0.39268334],
              [ 0.41942572],
              [ 0.41196584],
              [ 0.54435941],
              [ 0.49840622],
              [ 0.51627957]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203
    • 204
    • 205
    • 206
    • 207
    • 208
    • 209
    • 210
    • 211
    • 212
    • 213
    • 214
    • 215
    • 216
    • 217
    • 218
    • 219
    • 220
    • 221
    • 222
    • 223
    • 224
    • 225
    • 226
    • 227
    • 228
    • 229
    • 230
    • 231
    • 232
    • 233
    • 234
    • 235
    • 236
    • 237
    • 238
    • 239
    • 240
    • 241
    • 242
    • 243
    • 244
    • 245
    • 246
    • 247
    • 248
    • 249
    • 250
    • 251
    • 252
    • 253
    • 254
    • 255
    • 256
    • 257
    • 258
    • 259
    • 260
    • 261
    • 262
    • 263
    • 264
    • 265
    • 266
    • 267
    • 268
    • 269
    • 270
    • 271
    • 272
    • 273
    • 274
    • 275
    • 276
    • 277
    • 278
    • 279
    • 280
    • 281
    • 282
    • 283
    • 284
    • 285
    • 286
    • 287
    • 288
    • 289
    • 290
    • 291
    • 292
    • 293
    • 294
    • 295
    • 296
    • 297
    • 298
    • 299
    • 300

    这段代码是根据 x_data 生成对应的 y_data,并加入了一些噪声。

    具体实现中,首先利用 np.square(x_data)x_data 中的每个元素平方,然后减去一个常数 0.5,最后加上一些噪声,生成与 x_data 形状相同的 y_data 数组。

    由于 x_data 是一个二维数组,y_data 需要与它形状相同,因此 y_data 也是一个二维数组,包含 300 个样本和每个样本的输出值。

    4. 定义基本模型

    # 定义placeholder用来输入数据到神经网络,其中1表只有一个特征,也就是维度为一维数据
    xs = tf.placeholder(tf.float32, [None, 1])
    ys = tf.placeholder(tf.float32, [None, 1])
    # add hidden layer
    l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
    # add output layer
    prediction = add_layer(l1, 10, 1, activation_function=None)
    
    # 代价函数,reduce_mean为求均值,reduce_sum为求和,reduction_indices为数据处理的维度
    loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
    
    # 将代价函数传到梯度下降,学习速率为0.1,这里包含权重的训练,会更新权重
    train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    这段代码定义了一个神经网络模型。

    首先,使用 tf.placeholder 创建两个占位符 xsys,分别用来输入训练数据和真实标签,其中 None 表示样本数量是不确定的,只确定数据的维度是 1 维。

    接下来,调用 add_layer 函数添加一个隐藏层,输入为 xs神经元个数为 10 个,激活函数为 ReLU。然后再调用一次 add_layer 函数添加一个输出层,输入为隐藏层的输出,神经元个数为 1 个,激活函数为 None(也就是不使用激活函数)。

    然后,定义了一个代价函数 loss,用来衡量预测值与真实标签之间的差距,这里选用的是 mean square error(均方误差)作为代价函数。具体实现中,使用 tf.square 计算每个样本的预测值与真实标签之间的差距,然后使用 tf.reduce_mean 计算所有样本的差距的平均值。

    最后,使用 tf.train.GradientDescentOptimizer 创建一个优化器,设定学习速率为 0.1,然后调用 minimize 方法去最小化代价函数 loss,这里会更新神经网络的权重和偏置,训练模型使得预测值与真实标签不断接近。

    5. 变量初始化

    # important step
    # tf.initialize_all_variables() no long valid from
    # 2017-03-02 if using tensorflow >= 0.12
    # 变量初始化
    if int((tf.__version__).split('.')[1]) < 12:
        init = tf.initialize_all_variables()
    else:
        init = tf.global_variables_initializer()
    sess = tf.Session()  # 打开TensorFlow
    sess.run(init)  # 执行变量初始化
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    这段代码主要是进行了 TensorFlow 的初始化操作。由于 TensorFlow 的版本问题,原来的 tf.initialize_all_variables() 不再被支持,改为了 tf.global_variables_initializer()。然后创建了一个 tf.Session 对象 sess,用来执行 TensorFlow 中定义的操作。

    最后,执行 init 操作进行变量的初始化。这里会将之前定义的变量(包括权重和偏置)都初始化为一些随机值,用来开始训练模型。

    6. 开始训练

    for i in range(1000):  # 梯度下降迭代一千次
        # training
        sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
        # 执行梯度下降算法,并且将样本喂给损失函数
        if i % 50 == 0:
            # 每50次迭代输出代价函数的值
            print(sess.run(loss, feed_dict={xs: x_data, ys: y_data}))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    0.18214862
    0.010138167
    0.0071248626
    0.0069830194
    0.0068635535
    0.0067452225
    0.006626569
    0.0065121166
    0.0064035906
    0.006295418
    0.0061897114
    0.0060903295
    0.005990808
    0.0058959606
    0.0058057955
    0.0057200184
    0.005637601
    0.0055605737
    0.0054863705
    0.005413457
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    这段代码是用来训练模型的,实现了梯度下降的过程。循环了 1000 次,每 50 次迭代输出一次代价函数 loss 的值。在每次迭代中,通过 sess.run(train_step, feed_dict={xs: x_data, ys: y_data}) 执行了一次梯度下降,并把样本传入损失函数中进行计算。

    这个训练过程中的输出可以用来观察代价函数的变化情况,如果随着迭代的进行,代价函数的值逐渐减小,那么就表示模型的训练效果越来越好,模型越来越能够准确地预测目标变量。

    附:系列文章

    序号文章目录直达链接
    1波士顿房价预测https://want595.blog.csdn.net/article/details/132181950
    2鸢尾花数据集分析https://want595.blog.csdn.net/article/details/132182057
    3特征处理https://want595.blog.csdn.net/article/details/132182165
    4交叉验证https://want595.blog.csdn.net/article/details/132182238
    5构造神经网络示例https://want595.blog.csdn.net/article/details/132182341
    6使用TensorFlow完成线性回归https://want595.blog.csdn.net/article/details/132182417
    7使用TensorFlow完成逻辑回归https://want595.blog.csdn.net/article/details/132182496
    8TensorBoard案例https://want595.blog.csdn.net/article/details/132182584
    9使用Keras完成线性回归https://want595.blog.csdn.net/article/details/132182723
    10使用Keras完成逻辑回归https://want595.blog.csdn.net/article/details/132182795
    11使用Keras预训练模型完成猫狗识别https://want595.blog.csdn.net/article/details/132243928
    12使用PyTorch训练模型https://want595.blog.csdn.net/article/details/132243989
    13使用Dropout抑制过拟合https://want595.blog.csdn.net/article/details/132244111
    14使用CNN完成MNIST手写体识别(TensorFlow)https://want595.blog.csdn.net/article/details/132244499
    15使用CNN完成MNIST手写体识别(Keras)https://want595.blog.csdn.net/article/details/132244552
    16使用CNN完成MNIST手写体识别(PyTorch)https://want595.blog.csdn.net/article/details/132244641
    17使用GAN生成手写数字样本https://want595.blog.csdn.net/article/details/132244764
    18自然语言处理https://want595.blog.csdn.net/article/details/132276591
  • 相关阅读:
    集群-Nacos-2.2.3、Nginx-1.24.0集群配置
    mysql函数汇总之系统信息函数
    css3 hover效果
    要便利,更要安全可靠,数字钥匙优化升级迫在眉睫
    node 第十九天 使用node插件node-jsonwebtoken实现身份令牌jwt认证
    2022年笔试知识总结展望(前后端均有)
    一位末流211新大二同学的暑期总结
    蓝色荧光油溶性/三元核壳结构CuInS2/ZnS/亲水性CZTS量子点
    jenkins+git持续集成配置
    第一章 概述
  • 原文地址:https://blog.csdn.net/m0_68111267/article/details/132182341