• Matplotlib补充:科研绘图利器(写论文、数据可视化必备)


    前言

    上篇对Matplotlib所绘图中的各个部件进行了简单介绍,并将常见图绘制方法进行了说明,以使大家可以认识并理解相关内容,最后我对常见的两个问题的解决方法进行了实例讲解,包括:1)中文显示;2)坐标轴刻度、坐标范围自定义设置。原文链接:https://blog.csdn.net/qq_43665602/article/details/126870205,这篇文章我对plot函数的内容进行补充,如此大家在使用时拥有了更高的自由度。此外,在深度学习中大家免不了会对图像进行显示处理或者对训练数据进行可视化(比如损失函数、准确率等),这里对这部分内容也进行介绍。

    一、关于plot函数

    之前介绍了plot函数可用来绘制折线图、曲线图等类型,通过一些基本的设置即可完成,这里我主要扩展两部分内容:
    1)根据需求调整绘制目标的线条颜色、风格等;
    2)使用plot函数绘制散点图(数据量大时plot函数比scatter函数效率更高);

    1.折线图/曲线图扩充内容

    在绘图时我们可以通过设置plot函数的一些参数,从而根据自己的想法更自由地对所绘图进行调整。其中我们可以对线条的颜色、风格类型进行指定,编写代码时有两种方式:1)分别通过指定color、linestyle参数进行设置;2)同时指定颜色和线条类型;
    此处先对颜色以及线条类型的取值情况进行说明(取值类型很多,我只展示比较常用的):
    线条颜色设置:

    取值类型1:颜色全称取值类型1:颜色全称
    redr
    blueb
    greeng
    yellowy

    线条类型设置(依次表示实线、虚线、长短点虚线、点线:

    取值类型1:类型全称取值类型2:符号表示
    solid-
    dashed
    dashdot-.
    dotted:

    下面对应不同的指定方式进行举例说明:
    (1)自定义线条类型及颜色
    默认参数情况下:

    axes.plot(x,y1)  # 绘图,若要在同一图中绘制多个函数,则多次调用plot函数即可
    axes.plot(x,y2)
    
    • 1
    • 2

    在这里插入图片描述

    分别通过color、linestyle参数指定目标颜色及线条类型,此处y1使用红色虚线,y2使用绿色点线:

    axes.plot(x,y1,color='red',linestyle='--')  # 绘图,若要在同一图中绘制多个函数,则多次调用plot函数即可
    axes.plot(x,y2,color='green',linestyle=':')
    
    • 1
    • 2

    在这里插入图片描述
    完整代码:

    import numpy as np
    from matplotlib import pyplot as plt
    
    
    # x=np.linspace(0,10,num=20)  # 生成等差数列
    # print(x)
    x=[ 0.         , 0.52631579 ,  1.05263158 ,  1.57894737  , 2.10526316  , 2.63157895,
      3.15789474 ,  3.68421053 ,  4.21052632 ,  4.73684211  , 5.26315789 ,  5.78947368,
      6.31578947  , 6.84210526,  7.36842105  , 7.89473684 ,  8.42105263 ,  8.94736842,
      9.47368421 , 10.        ]
    y1=[2*i for i in x]  # 直线
    y2=[np.sin(i) for i in x]    # 正弦函数
    figure,axes=plt.subplots()  # 构建画板figure,并划分一块通过axes控制所有坐标轴的绘图区域
    axes.plot(x,y1,color='red',linestyle='--')  # 绘图,若要在同一图中绘制多个函数,则多次调用plot函数即可
    axes.plot(x,y2,color='green',linestyle=':')
    axes.set_title("plot example")  # 设置所绘图的标题,即图名
    axes.set_ylabel("function value")  # 设置y轴标签,即y轴数据的含义
    axes.set_xlabel("x")  # 设置x轴标签,即x轴数据的含义
    axes.legend(["y=2x","y=sin(x)"])  # 设置不同函数的图例
    plt.show()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    (2)同时指定颜色和线条类型
    这种方法直接通过一个非关键字的参数进行指定,以线条类型、颜色顺序组合成一个字符串表示。此时需要注意,颜色和线条类型只能使用颜色缩写以及类型的符号表示才可以进行组合,否则会报错。

    import numpy as np
    from matplotlib import pyplot as plt
    
    
    # x=np.linspace(0,10,num=20)  # 生成等差数列
    # print(x)
    x=[ 0.         , 0.52631579 ,  1.05263158 ,  1.57894737  , 2.10526316  , 2.63157895,
      3.15789474 ,  3.68421053 ,  4.21052632 ,  4.73684211  , 5.26315789 ,  5.78947368,
      6.31578947  , 6.84210526,  7.36842105  , 7.89473684 ,  8.42105263 ,  8.94736842,
      9.47368421 , 10.        ]
    y1=[2*i for i in x]  # 直线
    y2=[np.sin(i) for i in x]    # 正弦函数
    figure,axes=plt.subplots()  # 构建画板figure,并划分一块通过axes控制所有坐标轴的绘图区域
    axes.plot(x,y1,'--r')  # 绘图,若要在同一图中绘制多个函数,则多次调用plot函数即可
    axes.plot(x,y2,':g')
    axes.set_title("plot example")  # 设置所绘图的标题,即图名
    axes.set_ylabel("function value")  # 设置y轴标签,即y轴数据的含义
    axes.set_xlabel("x")  # 设置x轴标签,即x轴数据的含义
    axes.legend(["y=2x","y=sin(x)"])  # 设置不同函数的图例
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    在这里插入图片描述

    2.使用plot函数绘制散点图

    之前介绍绘制散点图通常使用scatter()函数,但是scatter函数有一个很大的限制:每个点都需要独立渲染,数据量太大时则比较耗时。此处介绍使用更具灵活性的plot函数绘制散点图,这种方法进行绘图在近几年的paper中很常见,paper中各种各样的图非常醒目优雅。可以分别通过ms、mc、mec参数设置标记符号的类型、颜色以及边框颜色,下面把常见的符号类型进行说明:

    标记符号含义
    ‘o’实心圆
    ‘v’下三角
    ‘^’上三角
    ‘s’正方形
    'p[五边形
    *’型号
    ‘+’加号
    ‘x’乘号
    ‘D’菱形

    这里我拿之前我写论文对比试验中的案例进行展示,可使用不同的符号表示不同的算法:

    from matplotlib import pyplot as plt
    import numpy as np
    
    
    psnr_list=[32,28,40,25,42]
    time_list=[2,4,5,1,6]
    markers_list=['v','*','+','D','s']
    legends_list=['marker=v','marker=*','marker=+','marker=D','marker=s']
    
    
    figure,axes=plt.subplots()
    axes.set_title('psnr-time relation')
    axes.grid()
    for i in range(5):
        axes.plot(time_list[i],psnr_list[i],'{}'.format(markers_list[i]))
    axes.set_ylabel('psnr value')
    axes.set_xlabel('time')
    axes.legend(legends_list)
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    在这里插入图片描述

    二、图像的读取、显示、存储

    Matplotlib是通过Pillow来加载图像的,他们所支持的数据类型是有区别的,Pillow只能处理uint8类型,而Matplotlib可以处理float32和uint8两种类型,但大家需要注意,Matplotlib只有图像格式为PNG时才支持float32数据类型,并将像素值归一化到[0.0,1.0],当图像格式为其他格式时Matplotlib只支持uint8类型

    # matplotlib可以处理float32、uint8,但除 PNG 以外的任何格式的图像读取/写入仅限于 uint8 数据。
    # 对于 RGB 和 RGBA 图像,Matplotlib 支持 float32 和 uint8 数据类型。对于灰度,Matplotlib 仅支持 float32
    from PIL import Image
    import numpy as np
    from matplotlib import pyplot as plt
    import matplotlib.image as  mimage  # 需要导包
    
    
    IMAGE_PATH='./iu.png'
    # IMAGE_PATH='./iu.jpg'
    
    img=mimage.imread(IMAGE_PATH)
    print(type(img))  # 
    print(type(img[0][0][0]))
    print(img[0][0][0])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    PNG格式
    <class 'numpy.ndarray'>
    <class 'numpy.float32'>
    0.24705882
    
    • 1
    • 2
    • 3
    • 4
    JPEG格式
    <class 'numpy.ndarray'>
    <class 'numpy.uint8'>
    63
    
    • 1
    • 2
    • 3
    • 4

    1.读取图像

    在读取图像之前需要根据需要对图像格式进行调整,否则会出现问题:

    import matplotlib.image as  mimage
    
    
    IMAGE_PATH='./iu2.png'
    img=mimage.imread(IMAGE_PATH)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    2.显示图像

    pyplot提供了imshow()函数用于图像显示:

    from matplotlib import pyplot as plt
    
    plt.imshow(img)
    plt.show()
    
    • 1
    • 2
    • 3
    • 4

    在这里插入图片描述

    3.存储图像

    在存储图像时,这里有两种理解:
    1)只存储图像数据,即图像内容;
    2)将整个figure画板的内容保存下来;
    两种方式分别如下:
    (1)只存储图像数据:

    mimage.imsave('./save.png',img)
    
    • 1

    在这里插入图片描述
    (2)将整个figure画板的内容保存下来:
    保存整个画板内容的语句一定要在plt.show()之前调用:

    plt.savefig('./content.png')
    
    plt.show()
    
    • 1
    • 2
    • 3

    在这里插入图片描述

    4.案例完整代码

    import matplotlib.image as  mimage
    from matplotlib import pyplot as plt
    
    
    IMAGE_PATH='./iu2.png'
    img=mimage.imread(IMAGE_PATH)
    
    plt.imshow(img)
    plt.show()
    
    mimage.imsave('./save.png',img)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    三、训练损失、准确率

    在训练神经网络时我们通常会对训练过程中的一些数据进行记录、可视化,用于观察网络的收敛状况以及网络的性能,之前介绍过使用tensorboard对这些进行记录并展示,这里介绍使用matplotlib中的plot函数如何操作。我们知道plot函数所能接受的输入是numpy数组形式,或者可以表示为numpy数组形式的数据,所以我们可以在训练过程中将每轮的训练数据依次添加到列表中,然后使用plot函数绘制即可,下面我使用代码简单地模拟一下这个过程。

    1.损失函数

    损失函数通常包括训练损失和验证损失两部分内容,结合二者之间的趋势可以观察网络的收敛情况:

    import torch
    from matplotlib import pyplot as plt
    import numpy as np
    from matplotlib.ticker import MultipleLocator
    
    
    # 1.损失函数:训练损损失、验证损失
    x_list=np.linspace(0,19,20)
    # print(x_list)
    train_loss_list=[]
    val_loss_list=[]
    for epoch in range(20):
        train_loss=0
        for train_batch in range(100):
            loss=np.random.random()  # 随机生成[0.0, 1.0)之间的数
            train_loss+=loss
        train_loss_list.append(train_loss/100)
    
        val_loss = 0
        for val_batch in range(60):
            loss=np.random.random() + 1.0
            val_loss+=loss
        val_loss_list.append(val_loss/100)
    
    figure,axes=plt.subplots()
    axes.plot(x_list,train_loss_list)
    axes.plot(x_list,val_loss_list)
    axes.set_title("loss curve")
    axes.set_ylabel("loss value")
    axes.set_xlabel("epoch")
    axes.xaxis.set_major_locator(MultipleLocator(1.0))  # 设置x轴刻度
    axes.legend(["train_loss","val_loss"])
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33

    在这里插入图片描述

    2.准确率

    在一些分类任务中我们通常使用准确率作为一个主要的衡量指标,就测试集而言,准确率当然是越高越好,表示网络性能越好:

    x_list=np.linspace(0,19,20)
    val_acc_list=[]
    for epoch in range(20):
        val_acc = 0
        for val_batch in range(60):
            acc=np.random.random()
            val_acc+=acc
        val_acc_list.append(val_acc/60)
    
    figure,axes=plt.subplots()
    axes.plot(x_list,val_acc_list)
    axes.set_title("accuracy curve")
    axes.set_ylabel("accuracy value")
    axes.set_xlabel("epoch")
    axes.xaxis.set_major_locator(MultipleLocator(1.0))
    axes.legend(["val_accuracy"])
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    在这里插入图片描述

  • 相关阅读:
    O-Star|再相识
    【每日刷题】Day72
    C++ primer plus--输入、输出和文件
    prompt learning 术语中文翻译
    06、HSMS协议介绍
    京东云开发者|IoT运维 - 如何部署一套高可用K8S集群
    软考高项-配置管理
    【VS Code】使用 VS Code 登陆远程服务器上的 Docker 容器
    东方博易OJ——1005 - 【入门】已知一个圆的半径,求解该圆的面积和周长
    Vue源码:手写patch函数,diff算法
  • 原文地址:https://blog.csdn.net/qq_43665602/article/details/126912306