上篇对Matplotlib所绘图中的各个部件进行了简单介绍,并将常见图绘制方法进行了说明,以使大家可以认识并理解相关内容,最后我对常见的两个问题的解决方法进行了实例讲解,包括:1)中文显示;2)坐标轴刻度、坐标范围自定义设置。原文链接:https://blog.csdn.net/qq_43665602/article/details/126870205,这篇文章我对plot函数的内容进行补充,如此大家在使用时拥有了更高的自由度。此外,在深度学习中大家免不了会对图像进行显示处理或者对训练数据进行可视化(比如损失函数、准确率等),这里对这部分内容也进行介绍。
之前介绍了plot函数可用来绘制折线图、曲线图等类型,通过一些基本的设置即可完成,这里我主要扩展两部分内容:
1)根据需求调整绘制目标的线条颜色、风格等;
2)使用plot函数绘制散点图(数据量大时plot函数比scatter函数效率更高);
在绘图时我们可以通过设置plot函数的一些参数,从而根据自己的想法更自由地对所绘图进行调整。其中我们可以对线条的颜色、风格类型进行指定,编写代码时有两种方式:1)分别通过指定color、linestyle参数进行设置;2)同时指定颜色和线条类型;
此处先对颜色以及线条类型的取值情况进行说明(取值类型很多,我只展示比较常用的):
线条颜色设置:
取值类型1:颜色全称 | 取值类型1:颜色全称 |
---|---|
red | r |
blue | b |
green | g |
yellow | y |
线条类型设置(依次表示实线、虚线、长短点虚线、点线:
取值类型1:类型全称 | 取值类型2:符号表示 |
---|---|
solid | - |
dashed | – |
dashdot | -. |
dotted | : |
下面对应不同的指定方式进行举例说明:
(1)自定义线条类型及颜色
默认参数情况下:
axes.plot(x,y1) # 绘图,若要在同一图中绘制多个函数,则多次调用plot函数即可
axes.plot(x,y2)
分别通过color、linestyle参数指定目标颜色及线条类型,此处y1使用红色虚线,y2使用绿色点线:
axes.plot(x,y1,color='red',linestyle='--') # 绘图,若要在同一图中绘制多个函数,则多次调用plot函数即可
axes.plot(x,y2,color='green',linestyle=':')
完整代码:
import numpy as np
from matplotlib import pyplot as plt
# x=np.linspace(0,10,num=20) # 生成等差数列
# print(x)
x=[ 0. , 0.52631579 , 1.05263158 , 1.57894737 , 2.10526316 , 2.63157895,
3.15789474 , 3.68421053 , 4.21052632 , 4.73684211 , 5.26315789 , 5.78947368,
6.31578947 , 6.84210526, 7.36842105 , 7.89473684 , 8.42105263 , 8.94736842,
9.47368421 , 10. ]
y1=[2*i for i in x] # 直线
y2=[np.sin(i) for i in x] # 正弦函数
figure,axes=plt.subplots() # 构建画板figure,并划分一块通过axes控制所有坐标轴的绘图区域
axes.plot(x,y1,color='red',linestyle='--') # 绘图,若要在同一图中绘制多个函数,则多次调用plot函数即可
axes.plot(x,y2,color='green',linestyle=':')
axes.set_title("plot example") # 设置所绘图的标题,即图名
axes.set_ylabel("function value") # 设置y轴标签,即y轴数据的含义
axes.set_xlabel("x") # 设置x轴标签,即x轴数据的含义
axes.legend(["y=2x","y=sin(x)"]) # 设置不同函数的图例
plt.show()
(2)同时指定颜色和线条类型
这种方法直接通过一个非关键字的参数进行指定,以线条类型、颜色顺序组合成一个字符串表示。此时需要注意,颜色和线条类型只能使用颜色缩写以及类型的符号表示才可以进行组合,否则会报错。
import numpy as np
from matplotlib import pyplot as plt
# x=np.linspace(0,10,num=20) # 生成等差数列
# print(x)
x=[ 0. , 0.52631579 , 1.05263158 , 1.57894737 , 2.10526316 , 2.63157895,
3.15789474 , 3.68421053 , 4.21052632 , 4.73684211 , 5.26315789 , 5.78947368,
6.31578947 , 6.84210526, 7.36842105 , 7.89473684 , 8.42105263 , 8.94736842,
9.47368421 , 10. ]
y1=[2*i for i in x] # 直线
y2=[np.sin(i) for i in x] # 正弦函数
figure,axes=plt.subplots() # 构建画板figure,并划分一块通过axes控制所有坐标轴的绘图区域
axes.plot(x,y1,'--r') # 绘图,若要在同一图中绘制多个函数,则多次调用plot函数即可
axes.plot(x,y2,':g')
axes.set_title("plot example") # 设置所绘图的标题,即图名
axes.set_ylabel("function value") # 设置y轴标签,即y轴数据的含义
axes.set_xlabel("x") # 设置x轴标签,即x轴数据的含义
axes.legend(["y=2x","y=sin(x)"]) # 设置不同函数的图例
plt.show()
之前介绍绘制散点图通常使用scatter()函数,但是scatter函数有一个很大的限制:每个点都需要独立渲染,数据量太大时则比较耗时。此处介绍使用更具灵活性的plot函数绘制散点图,这种方法进行绘图在近几年的paper中很常见,paper中各种各样的图非常醒目优雅。可以分别通过ms、mc、mec参数设置标记符号的类型、颜色以及边框颜色,下面把常见的符号类型进行说明:
标记符号 | 含义 |
---|---|
‘o’ | 实心圆 |
‘v’ | 下三角 |
‘^’ | 上三角 |
‘s’ | 正方形 |
'p[ | 五边形 |
*’ | 型号 |
‘+’ | 加号 |
‘x’ | 乘号 |
‘D’ | 菱形 |
这里我拿之前我写论文对比试验中的案例进行展示,可使用不同的符号表示不同的算法:
from matplotlib import pyplot as plt
import numpy as np
psnr_list=[32,28,40,25,42]
time_list=[2,4,5,1,6]
markers_list=['v','*','+','D','s']
legends_list=['marker=v','marker=*','marker=+','marker=D','marker=s']
figure,axes=plt.subplots()
axes.set_title('psnr-time relation')
axes.grid()
for i in range(5):
axes.plot(time_list[i],psnr_list[i],'{}'.format(markers_list[i]))
axes.set_ylabel('psnr value')
axes.set_xlabel('time')
axes.legend(legends_list)
plt.show()
Matplotlib是通过Pillow来加载图像的,他们所支持的数据类型是有区别的,Pillow只能处理uint8类型,而Matplotlib可以处理float32和uint8两种类型,但大家需要注意,Matplotlib只有图像格式为PNG时才支持float32数据类型,并将像素值归一化到[0.0,1.0],当图像格式为其他格式时Matplotlib只支持uint8类型:
# matplotlib可以处理float32、uint8,但除 PNG 以外的任何格式的图像读取/写入仅限于 uint8 数据。
# 对于 RGB 和 RGBA 图像,Matplotlib 支持 float32 和 uint8 数据类型。对于灰度,Matplotlib 仅支持 float32
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.image as mimage # 需要导包
IMAGE_PATH='./iu.png'
# IMAGE_PATH='./iu.jpg'
img=mimage.imread(IMAGE_PATH)
print(type(img)) #
print(type(img[0][0][0]))
print(img[0][0][0])
PNG格式
<class 'numpy.ndarray'>
<class 'numpy.float32'>
0.24705882
JPEG格式
<class 'numpy.ndarray'>
<class 'numpy.uint8'>
63
在读取图像之前需要根据需要对图像格式进行调整,否则会出现问题:
import matplotlib.image as mimage
IMAGE_PATH='./iu2.png'
img=mimage.imread(IMAGE_PATH)
pyplot提供了imshow()函数用于图像显示:
from matplotlib import pyplot as plt
plt.imshow(img)
plt.show()
在存储图像时,这里有两种理解:
1)只存储图像数据,即图像内容;
2)将整个figure画板的内容保存下来;
两种方式分别如下:
(1)只存储图像数据:
mimage.imsave('./save.png',img)
(2)将整个figure画板的内容保存下来:
保存整个画板内容的语句一定要在plt.show()之前调用:
plt.savefig('./content.png')
plt.show()
import matplotlib.image as mimage
from matplotlib import pyplot as plt
IMAGE_PATH='./iu2.png'
img=mimage.imread(IMAGE_PATH)
plt.imshow(img)
plt.show()
mimage.imsave('./save.png',img)
在训练神经网络时我们通常会对训练过程中的一些数据进行记录、可视化,用于观察网络的收敛状况以及网络的性能,之前介绍过使用tensorboard对这些进行记录并展示,这里介绍使用matplotlib中的plot函数如何操作。我们知道plot函数所能接受的输入是numpy数组形式,或者可以表示为numpy数组形式的数据,所以我们可以在训练过程中将每轮的训练数据依次添加到列表中,然后使用plot函数绘制即可,下面我使用代码简单地模拟一下这个过程。
损失函数通常包括训练损失和验证损失两部分内容,结合二者之间的趋势可以观察网络的收敛情况:
import torch
from matplotlib import pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
# 1.损失函数:训练损损失、验证损失
x_list=np.linspace(0,19,20)
# print(x_list)
train_loss_list=[]
val_loss_list=[]
for epoch in range(20):
train_loss=0
for train_batch in range(100):
loss=np.random.random() # 随机生成[0.0, 1.0)之间的数
train_loss+=loss
train_loss_list.append(train_loss/100)
val_loss = 0
for val_batch in range(60):
loss=np.random.random() + 1.0
val_loss+=loss
val_loss_list.append(val_loss/100)
figure,axes=plt.subplots()
axes.plot(x_list,train_loss_list)
axes.plot(x_list,val_loss_list)
axes.set_title("loss curve")
axes.set_ylabel("loss value")
axes.set_xlabel("epoch")
axes.xaxis.set_major_locator(MultipleLocator(1.0)) # 设置x轴刻度
axes.legend(["train_loss","val_loss"])
plt.show()
在一些分类任务中我们通常使用准确率作为一个主要的衡量指标,就测试集而言,准确率当然是越高越好,表示网络性能越好:
x_list=np.linspace(0,19,20)
val_acc_list=[]
for epoch in range(20):
val_acc = 0
for val_batch in range(60):
acc=np.random.random()
val_acc+=acc
val_acc_list.append(val_acc/60)
figure,axes=plt.subplots()
axes.plot(x_list,val_acc_list)
axes.set_title("accuracy curve")
axes.set_ylabel("accuracy value")
axes.set_xlabel("epoch")
axes.xaxis.set_major_locator(MultipleLocator(1.0))
axes.legend(["val_accuracy"])
plt.show()