在训练深度学习网络时,监控训练进度通常很有用。通过在训练过程中绘制各种指标,您可以了解训练的进度情况。例如,您可以确定网络准确度是否改善以及改善速度,还可以确定网络是否开始过拟合训练数据。
在 trainingOptions 中将 ‘training-progress’ 指定为 ‘Plots’ 值并开始网络训练时,trainNetwork 会创建一个图窗并在每次迭代时显示训练指标。 每次迭代都是对梯度的一次估计和对网络参数的一次更新。如果在 trainingOptions 中指定验证数据,则每次 trainNetwork 验证网络时,该图窗都会显示验证指标。该图窗绘制以下内容:
训练准确度 - 针对每个小批量的分类准确度。
经过平滑处理的训练准确度 - 经过平滑处理的训练准确度,通过将平滑算法应用于训练准确度来获得。它的噪声低于未平滑的准确度,因此更易于揭示趋势。
验证准确度 - 针对整个验证集的分类准确度(使用 trainingOptions 指定)。
训练损失、经过平滑处理的训练损失和验证损失 - 分别指每个小批量的损失、其经过平滑处理的版本以及验证集的损失。如果网络的最终层是一个 classificationLayer,则损失函数是交叉熵损失。有关分类和回归问题的损失函数的详细信息,请参阅输出层。
对于回归网络,该图窗绘制均方根误差 (RMSE) 而不是准确度。
图窗使用交替底色来标记每一轮训练。一轮训练是对整个训练数据集的一次完整遍历。
在训练过程中,您可以通过点击右上角的停止按钮停止训练并返回网络的当前状态。例如,您可能希望在网络准确度达到稳定水平并且准确度明显不再提高时停止训练。点击停止按钮后,可能需要一段时间才能完成训练。训练完成后,trainNetwork 将返回经过训练的网络。
训练结束后,查看结果,其中显示最终验证准确度和训练结束的原因。最终验证指标在绘图中标注为 Final。如果您的网络包含批量归一化层,则最终验证指标可以与训练过程中评估出的验证指标不同。这是因为在训练完成后,用于批量归一化的均值和方差统计量可能会有所不同。例如,如果 ‘BatchNormalizationStatisics’ 训练选项为 ‘population’,则在训练后,软件通过再次使训练数据穿过来完成批量归一化统计,并使用得到的均值和方差。如果 ‘BatchNormalizationStatisics’ 训练选项为 ‘moving’,则软件在训练过程中使用运行估计来逼近统计量,并使用统计量的最新值。
在训练过程中绘制训练进度
训练网络并在训练过程中绘制训练进度。
加载训练数据,其中包含 5000 个数字图像。留出 1000 个图像用于网络验证。
[XTrain,YTrain] = digitTrain4DArrayData;
idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];
构建网络以对数字图像数据进行分类。
layers = [
imageInputLayer([28 28 1])
convolution2dLayer(3,8,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];
指定网络训练的选项。要在训练过程中按固定时间间隔验证网络,请指定验证数据。选择 ‘ValidationFrequency’ 值,以使网络大致在每轮训练都被验证一次。要在训练过程中绘制训练进度,请将 ‘training-progress’ 指定为 ‘Plots’ 值。
options = trainingOptions('sgdm', ...
'MaxEpochs',8, ...
'ValidationData',{XValidation,YValidation}, ...
'ValidationFrequency',30, ...
'Verbose',false, ...
'Plots','training-progress');
训练网络。
net = trainNetwork(XTrain,YTrain,layers,options);