神经网络是一种仿生学原理的机器学习算法,灵感来源于人脑的神经系统。它由多个神经元(或称为节点)组成,这些神经元通过连接权重形成复杂的网络结构,用来学习和提取输入数据的特征,并用于分类、回归、聚类等任务。
注明:该代码用来训练一个神经网络,网络拟合y = x^2-0.5+noise,该神经网络的结构是输入层为一个神经元,隐藏层为十个神经元,输出层为一个神经元
# 导入相关库
import tensorflow as tf # 用来构造神经网络
import numpy as np # 用来构造数据结构和处理数据模块
这段代码使用了两个 Python 模块:
tensorflow
:这是 Google 开源的机器学习框架,用来构造神经网络和训练模型。
numpy
:这是 Python 中用于矩阵/数组运算的基础库,用来构造数据结构和处理数据。
具体来说:
import tensorflow as tf
导入 TensorFlow 库并给它起个别名 tf
。import numpy as np
导入 NumPy 库并给它起个别名 np
。# 定义一个层
def add_layer(inputs, in_size, out_size, activation_function=None):
# 定义一个层,其中inputs为输入,in_size为上一层神经元数,out_size为该层神经元数
# activation_function为激励函数
Weights = tf.Variable(tf.random_normal([in_size, out_size]))
# 初始权重随机生成比较好,in_size,out_size为该权重维度
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
# 偏置
Wx_plus_b = tf.matmul(inputs, Weights) + biases
# matmul为矩阵里的函数相乘
if activation_function is None:
outputs = Wx_plus_b # 如果激活函数为空,则不激活,保持数据
else:
outputs = activation_function(Wx_plus_b)
# 如果激活函数不为空,则激活,并且返回激活后的值
return outputs # 返回激活后的值
这段代码定义了一个函数 add_layer
,用于添加一层神经网络。
代码中的参数解释如下:
inputs
:该层的输入。in_size
:该层的输入维度,即上一层神经元数。out_size
:该层的输出维度,即该层神经元数。activation_function
:该层使用的激活函数,可以为空。函数内部逻辑:
Weights = tf.Variable(tf.random_normal([in_size, out_size]))
:定义该层的权重,使用随机生成的正态分布数据,维度为 [in_size, out_size]
。biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
:定义该层的偏置,使用全0矩阵,维度为 [1, out_size]
,并且加上0.1,以避免点评中出现0的情况。Wx_plus_b = tf.matmul(inputs, Weights) + biases
:使用矩阵乘法计算该层的输出,即加权和加上偏置。if activation_function is None:
:如果激活函数为空,则直接将加权和加上偏置的结果作为该层的输出。else:
:否则,对加权和加上偏置的结果进行激活函数的处理,并将处理结果作为该层的输出。return outputs
:返回该层的输出。总的来说,该函数的作用是创建一个神经网络层,将输入经过加权和加上偏置的运算,并使用激活函数得到输出。
# 构造一些样本,用来训练神经网络
x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
# 值为(-1,1)之间的数,有300个
noise = np.random.normal(0, 0.05, x_data.shape)
x_data
array([[-1. ],
[-0.99331104],
[-0.98662207],
[-0.97993311],
[-0.97324415],
[-0.96655518],
[-0.95986622],
[-0.95317726],
[-0.94648829],
[-0.93979933],
[-0.93311037],
[-0.9264214 ],
[-0.91973244],
[-0.91304348],
[-0.90635452],
[-0.89966555],
[-0.89297659],
[-0.88628763],
[-0.87959866],
[-0.8729097 ],
[-0.86622074],
[-0.85953177],
[-0.85284281],
[-0.84615385],
[-0.83946488],
[-0.83277592],
[-0.82608696],
[-0.81939799],
[-0.81270903],
[-0.80602007],
[-0.7993311 ],
[-0.79264214],
[-0.78595318],
[-0.77926421],
[-0.77257525],
[-0.76588629],
[-0.75919732],
[-0.75250836],
[-0.7458194 ],
[-0.73913043],
[-0.73244147],
[-0.72575251],
[-0.71906355],
[-0.71237458],
[-0.70568562],
[-0.69899666],
[-0.69230769],
[-0.68561873],
[-0.67892977],
[-0.6722408 ],
[-0.66555184],
[-0.65886288],
[-0.65217391],
[-0.64548495],
[-0.63879599],
[-0.63210702],
[-0.62541806],
[-0.6187291 ],
[-0.61204013],
[-0.60535117],
[-0.59866221],
[-0.59197324],
[-0.58528428],
[-0.57859532],
[-0.57190635],
[-0.56521739],
[-0.55852843],
[-0.55183946],
[-0.5451505 ],
[-0.53846154],
[-0.53177258],
[-0.52508361],
[-0.51839465],
[-0.51170569],
[-0.50501672],
[-0.49832776],
[-0.4916388 ],
[-0.48494983],
[-0.47826087],
[-0.47157191],
[-0.46488294],
[-0.45819398],
[-0.45150502],
[-0.44481605],
[-0.43812709],
[-0.43143813],
[-0.42474916],
[-0.4180602 ],
[-0.41137124],
[-0.40468227],
[-0.39799331],
[-0.39130435],
[-0.38461538],
[-0.37792642],
[-0.37123746],
[-0.36454849],
[-0.35785953],
[-0.35117057],
[-0.34448161],
[-0.33779264],
[-0.33110368],
[-0.32441472],
[-0.31772575],
[-0.31103679],
[-0.30434783],
[-0.29765886],
[-0.2909699 ],
[-0.28428094],
[-0.27759197],
[-0.27090301],
[-0.26421405],
[-0.25752508],
[-0.25083612],
[-0.24414716],
[-0.23745819],
[-0.23076923],
[-0.22408027],
[-0.2173913 ],
[-0.21070234],
[-0.20401338],
[-0.19732441],
[-0.19063545],
[-0.18394649],
[-0.17725753],
[-0.17056856],
[-0.1638796 ],
[-0.15719064],
[-0.15050167],
[-0.14381271],
[-0.13712375],
[-0.13043478],
[-0.12374582],
[-0.11705686],
[-0.11036789],
[-0.10367893],
[-0.09698997],
[-0.090301 ],
[-0.08361204],
[-0.07692308],
[-0.07023411],
[-0.06354515],
[-0.05685619],
[-0.05016722],
[-0.04347826],
[-0.0367893 ],
[-0.03010033],
[-0.02341137],
[-0.01672241],
[-0.01003344],
[-0.00334448],
[ 0.00334448],
[ 0.01003344],
[ 0.01672241],
[ 0.02341137],
[ 0.03010033],
[ 0.0367893 ],
[ 0.04347826],
[ 0.05016722],
[ 0.05685619],
[ 0.06354515],
[ 0.07023411],
[ 0.07692308],
[ 0.08361204],
[ 0.090301 ],
[ 0.09698997],
[ 0.10367893],
[ 0.11036789],
[ 0.11705686],
[ 0.12374582],
[ 0.13043478],
[ 0.13712375],
[ 0.14381271],
[ 0.15050167],
[ 0.15719064],
[ 0.1638796 ],
[ 0.17056856],
[ 0.17725753],
[ 0.18394649],
[ 0.19063545],
[ 0.19732441],
[ 0.20401338],
[ 0.21070234],
[ 0.2173913 ],
[ 0.22408027],
[ 0.23076923],
[ 0.23745819],
[ 0.24414716],
[ 0.25083612],
[ 0.25752508],
[ 0.26421405],
[ 0.27090301],
[ 0.27759197],
[ 0.28428094],
[ 0.2909699 ],
[ 0.29765886],
[ 0.30434783],
[ 0.31103679],
[ 0.31772575],
[ 0.32441472],
[ 0.33110368],
[ 0.33779264],
[ 0.34448161],
[ 0.35117057],
[ 0.35785953],
[ 0.36454849],
[ 0.37123746],
[ 0.37792642],
[ 0.38461538],
[ 0.39130435],
[ 0.39799331],
[ 0.40468227],
[ 0.41137124],
[ 0.4180602 ],
[ 0.42474916],
[ 0.43143813],
[ 0.43812709],
[ 0.44481605],
[ 0.45150502],
[ 0.45819398],
[ 0.46488294],
[ 0.47157191],
[ 0.47826087],
[ 0.48494983],
[ 0.4916388 ],
[ 0.49832776],
[ 0.50501672],
[ 0.51170569],
[ 0.51839465],
[ 0.52508361],
[ 0.53177258],
[ 0.53846154],
[ 0.5451505 ],
[ 0.55183946],
[ 0.55852843],
[ 0.56521739],
[ 0.57190635],
[ 0.57859532],
[ 0.58528428],
[ 0.59197324],
[ 0.59866221],
[ 0.60535117],
[ 0.61204013],
[ 0.6187291 ],
[ 0.62541806],
[ 0.63210702],
[ 0.63879599],
[ 0.64548495],
[ 0.65217391],
[ 0.65886288],
[ 0.66555184],
[ 0.6722408 ],
[ 0.67892977],
[ 0.68561873],
[ 0.69230769],
[ 0.69899666],
[ 0.70568562],
[ 0.71237458],
[ 0.71906355],
[ 0.72575251],
[ 0.73244147],
[ 0.73913043],
[ 0.7458194 ],
[ 0.75250836],
[ 0.75919732],
[ 0.76588629],
[ 0.77257525],
[ 0.77926421],
[ 0.78595318],
[ 0.79264214],
[ 0.7993311 ],
[ 0.80602007],
[ 0.81270903],
[ 0.81939799],
[ 0.82608696],
[ 0.83277592],
[ 0.83946488],
[ 0.84615385],
[ 0.85284281],
[ 0.85953177],
[ 0.86622074],
[ 0.8729097 ],
[ 0.87959866],
[ 0.88628763],
[ 0.89297659],
[ 0.89966555],
[ 0.90635452],
[ 0.91304348],
[ 0.91973244],
[ 0.9264214 ],
[ 0.93311037],
[ 0.93979933],
[ 0.94648829],
[ 0.95317726],
[ 0.95986622],
[ 0.96655518],
[ 0.97324415],
[ 0.97993311],
[ 0.98662207],
[ 0.99331104],
[ 1. ]])
这段代码使用 numpy 库创建了一个一维的数组 x_data
。
代码中的参数解释如下:
-1
:数组中数的最小值。1
:数组中数的最大值。300
:数组中数的个数。[:, np.newaxis]
:对数组进行转置,转换成二维数组。函数内部逻辑:
np.linspace(-1, 1, 300)
:返回一个数值范围在 -1 到 1 之间,总共有 300 个数的等差数列。即生成一个 ndarray 数组,包含 300 个数,分布在 -1 到 1 之间。[:, np.newaxis]
:将一维数组转化成列向量,即让数组的 shape 从 (300,)
变成 (300, 1)
。最终生成的 x_data
是一个二维数组,第一维度为 300
,第二维度为 1
,表示由 300
个样本组成,每个样本只有一个特征。
# 加入噪声会更贴近真实情况,噪声的值为(0,0.05)之间,结构为x_data一样
y_data = np.square(x_data) - 0.5 + noise
# y的结构
y_data
array([[ 0.59535036],
[ 0.46017998],
[ 0.47144478],
[ 0.45083795],
[ 0.58438217],
[ 0.38570118],
[ 0.43550029],
[ 0.40597571],
[ 0.3357524 ],
[ 0.35784864],
[ 0.34530231],
[ 0.32509701],
[ 0.25554733],
[ 0.32300801],
[ 0.2299959 ],
[ 0.35472568],
[ 0.31227671],
[ 0.30385068],
[ 0.29413844],
[ 0.18437787],
[ 0.28132819],
[ 0.25605309],
[ 0.23126361],
[ 0.23492797],
[ 0.18381621],
[ 0.10392937],
[ 0.13415913],
[ 0.14043649],
[ 0.11756826],
[ 0.12142749],
[ 0.12400694],
[ 0.08926307],
[ 0.15581832],
[ 0.16541106],
[-0.02582895],
[ 0.05924725],
[-0.04037454],
[ 0.03799003],
[ 0.09030832],
[ 0.05984324],
[-0.06569464],
[ 0.07973773],
[ 0.04297837],
[ 0.05169557],
[-0.00096191],
[-0.02049573],
[-0.03125322],
[-0.04545588],
[-0.02168901],
[ 0.01657517],
[-0.04315181],
[-0.09123519],
[-0.03292835],
[-0.1110189 ],
[-0.08212792],
[-0.10089535],
[-0.17406672],
[-0.10380731],
[-0.10774072],
[-0.21283138],
[-0.09788435],
[-0.10196452],
[-0.16439081],
[-0.15431978],
[-0.17778307],
[-0.18428537],
[-0.17874028],
[-0.10490738],
[-0.25076832],
[-0.16078044],
[-0.21572183],
[-0.15624353],
[-0.19591988],
[-0.31560742],
[-0.29593726],
[-0.26686787],
[-0.2999804 ],
[-0.30631065],
[-0.35305224],
[-0.31295125],
[-0.22996255],
[-0.22837061],
[-0.27266253],
[-0.31290802],
[-0.37188479],
[-0.20765034],
[-0.33860431],
[-0.31135236],
[-0.25249981],
[-0.26041048],
[-0.31486205],
[-0.30253306],
[-0.41624795],
[-0.40053837],
[-0.29939676],
[-0.32615377],
[-0.37377787],
[-0.32222027],
[-0.3158838 ],
[-0.43880087],
[-0.37510637],
[-0.46702321],
[-0.27058091],
[-0.52885151],
[-0.4061462 ],
[-0.4486374 ],
[-0.37819628],
[-0.34701947],
[-0.32454364],
[-0.3901839 ],
[-0.43293107],
[-0.47881173],
[-0.45280819],
[-0.49676541],
[-0.48955669],
[-0.45898691],
[-0.37473462],
[-0.43801531],
[-0.44793655],
[-0.57343047],
[-0.45262969],
[-0.40719677],
[-0.45423461],
[-0.45053051],
[-0.51046881],
[-0.41584096],
[-0.53328545],
[-0.44766406],
[-0.50158463],
[-0.42676031],
[-0.50552613],
[-0.36832989],
[-0.48699296],
[-0.41614151],
[-0.6175621 ],
[-0.48304532],
[-0.46115021],
[-0.40948908],
[-0.42017024],
[-0.50411757],
[-0.44530626],
[-0.46895275],
[-0.52127771],
[-0.50064585],
[-0.42210169],
[-0.58582837],
[-0.52049198],
[-0.45332091],
[-0.53465815],
[-0.5385712 ],
[-0.5654201 ],
[-0.54471377],
[-0.48109194],
[-0.44565732],
[-0.48112022],
[-0.46471786],
[-0.5452149 ],
[-0.52115601],
[-0.50234928],
[-0.54885558],
[-0.5279981 ],
[-0.53893795],
[-0.44286416],
[-0.45371406],
[-0.44633111],
[-0.57535678],
[-0.62918947],
[-0.41877124],
[-0.56263956],
[-0.51201705],
[-0.35016007],
[-0.49188897],
[-0.55766056],
[-0.38963378],
[-0.5038024 ],
[-0.51949984],
[-0.45229896],
[-0.49193029],
[-0.53472883],
[-0.48957523],
[-0.35561181],
[-0.4622668 ],
[-0.39177781],
[-0.43448445],
[-0.49854629],
[-0.49843105],
[-0.47704375],
[-0.36618194],
[-0.45177012],
[-0.41497222],
[-0.42152064],
[-0.48996608],
[-0.43010878],
[-0.42599962],
[-0.2841197 ],
[-0.38992082],
[-0.43802592],
[-0.42448799],
[-0.29514676],
[-0.37154091],
[-0.25426219],
[-0.44610678],
[-0.37120566],
[-0.3531599 ],
[-0.34606119],
[-0.29637877],
[-0.3693284 ],
[-0.36651142],
[-0.30025118],
[-0.31443603],
[-0.40824064],
[-0.31734053],
[-0.40807378],
[-0.33792031],
[-0.22414921],
[-0.37707072],
[-0.26776417],
[-0.29152204],
[-0.34066934],
[-0.19037511],
[-0.23552614],
[-0.2144995 ],
[-0.27628531],
[-0.27329725],
[-0.23910513],
[-0.30009859],
[-0.30192088],
[-0.16403744],
[-0.32546893],
[-0.25686912],
[-0.12515146],
[-0.21483097],
[-0.12779443],
[-0.28748063],
[-0.23782354],
[-0.16024807],
[-0.19062672],
[-0.15066097],
[-0.19043274],
[-0.16583211],
[-0.11201314],
[-0.05612149],
[-0.00847256],
[-0.1429705 ],
[-0.09595988],
[-0.09583441],
[-0.01372838],
[-0.04818834],
[-0.11840653],
[ 0.02184166],
[-0.07153294],
[-0.11556547],
[-0.04731049],
[-0.10774914],
[-0.014642 ],
[-0.01470962],
[-0.03259555],
[-0.04194347],
[ 0.08987345],
[-0.02027899],
[ 0.02418433],
[ 0.04298611],
[ 0.04130101],
[ 0.18010436],
[ 0.15480307],
[ 0.02719993],
[ 0.11508363],
[ 0.04309794],
[ 0.14060578],
[ 0.09377926],
[ 0.13887198],
[ 0.16148276],
[ 0.11398259],
[ 0.27887578],
[ 0.22775177],
[ 0.20749998],
[ 0.22107721],
[ 0.20854961],
[ 0.25411644],
[ 0.26561906],
[ 0.27540788],
[ 0.26946028],
[ 0.2390275 ],
[ 0.26051795],
[ 0.34424064],
[ 0.3240088 ],
[ 0.38040554],
[ 0.35717078],
[ 0.31357911],
[ 0.43825368],
[ 0.35709739],
[ 0.48101049],
[ 0.36024364],
[ 0.43253108],
[ 0.39268334],
[ 0.41942572],
[ 0.41196584],
[ 0.54435941],
[ 0.49840622],
[ 0.51627957]])
这段代码是根据 x_data 生成对应的 y_data,并加入了一些噪声。
具体实现中,首先利用 np.square(x_data)
将 x_data
中的每个元素平方,然后减去一个常数 0.5
,最后加上一些噪声,生成与 x_data
形状相同的 y_data
数组。
由于 x_data
是一个二维数组,y_data
需要与它形状相同,因此 y_data
也是一个二维数组,包含 300 个样本和每个样本的输出值。
# 定义placeholder用来输入数据到神经网络,其中1表只有一个特征,也就是维度为一维数据
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)
# 代价函数,reduce_mean为求均值,reduce_sum为求和,reduction_indices为数据处理的维度
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
# 将代价函数传到梯度下降,学习速率为0.1,这里包含权重的训练,会更新权重
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
这段代码定义了一个神经网络模型。
首先,使用 tf.placeholder
创建两个占位符 xs
和 ys
,分别用来输入训练数据和真实标签,其中 None
表示样本数量是不确定的,只确定数据的维度是 1 维。
接下来,调用 add_layer
函数添加一个隐藏层,输入为 xs
,神经元个数为 10 个,激活函数为 ReLU
。然后再调用一次 add_layer
函数添加一个输出层,输入为隐藏层的输出,神经元个数为 1 个,激活函数为 None
(也就是不使用激活函数)。
然后,定义了一个代价函数 loss
,用来衡量预测值与真实标签之间的差距,这里选用的是 mean square error(均方误差)作为代价函数。具体实现中,使用 tf.square
计算每个样本的预测值与真实标签之间的差距,然后使用 tf.reduce_mean
计算所有样本的差距的平均值。
最后,使用 tf.train.GradientDescentOptimizer
创建一个优化器,设定学习速率为 0.1,然后调用 minimize
方法去最小化代价函数 loss
,这里会更新神经网络的权重和偏置,训练模型使得预测值与真实标签不断接近。
# important step
# tf.initialize_all_variables() no long valid from
# 2017-03-02 if using tensorflow >= 0.12
# 变量初始化
if int((tf.__version__).split('.')[1]) < 12:
init = tf.initialize_all_variables()
else:
init = tf.global_variables_initializer()
sess = tf.Session() # 打开TensorFlow
sess.run(init) # 执行变量初始化
这段代码主要是进行了 TensorFlow 的初始化操作。由于 TensorFlow 的版本问题,原来的 tf.initialize_all_variables()
不再被支持,改为了 tf.global_variables_initializer()
。然后创建了一个 tf.Session
对象 sess
,用来执行 TensorFlow 中定义的操作。
最后,执行 init
操作进行变量的初始化。这里会将之前定义的变量(包括权重和偏置)都初始化为一些随机值,用来开始训练模型。
for i in range(1000): # 梯度下降迭代一千次
# training
sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
# 执行梯度下降算法,并且将样本喂给损失函数
if i % 50 == 0:
# 每50次迭代输出代价函数的值
print(sess.run(loss, feed_dict={xs: x_data, ys: y_data}))
0.18214862
0.010138167
0.0071248626
0.0069830194
0.0068635535
0.0067452225
0.006626569
0.0065121166
0.0064035906
0.006295418
0.0061897114
0.0060903295
0.005990808
0.0058959606
0.0058057955
0.0057200184
0.005637601
0.0055605737
0.0054863705
0.005413457
这段代码是用来训练模型的,实现了梯度下降的过程。循环了 1000 次,每 50 次迭代输出一次代价函数 loss
的值。在每次迭代中,通过 sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
执行了一次梯度下降,并把样本传入损失函数中进行计算。
这个训练过程中的输出可以用来观察代价函数的变化情况,如果随着迭代的进行,代价函数的值逐渐减小,那么就表示模型的训练效果越来越好,模型越来越能够准确地预测目标变量。