• 20000字深度讲解 Python 数据可视化神器 Plotly


    本文我们按照如下3 part来深入浅出地讲解plotly的使用方法。喜欢记得收藏、关注、点赞

    • part1: 深入原理, 本文第一节和第二节,分别介绍 go和px 的设计思想和绘图原理。

    • part2: 浅出范例, 本文第三节和第四节,对比性地展示 go和px 的五种绘图范例(柱形图、折线图、散点图、热力图、直方图)

    • part3: 深入实践, 本文第五节,展示一些plotly和机器学习相结合的综合应用范例。

    注:完整代码、资料、技术交流,文末提供

    一,plotly.graph_objs绘图原理

    plotly的Figure是由data(数据,数据包括图表类型(Line,Scatter,Area,Pie)和具体数据取值信息)和 layout(布局,包括xaxis,yaxis,title,legend等) 组成的对象。

    Figure对象就像一个透明的嵌套的Python dict 一样,可以通过修改元素值而改变其形态。

    import numpy as np 
    import plotly.graph_objs as go
    
    epoches = np.arange(20)
    accs = 1-0.9/(epoches+1)
    
    data = go.Scatter(x = epoches, y=accs, mode = "lines+markers",name = "acc",
                        marker = dict(size=8,color="blue"),
                        line= dict(width=2,color="blue",dash="dash"))
    
    layout = {"title":"accuracy via epoch",
              "xaxis.title":"epoch",
              "yaxis.title":"accuracy",
              "font.size":15}
    
    fig = go.Figure(data = data,layout=layout)
    fig.show() 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    图片

    如果要把图表的颜色改成红色实线怎么办呢?很简单,我们先print(fig)一下,观察它的结构,找到线的颜色和线型的属性获取方法,然后直接对相应属性赋值就可以了。

    print(fig.data)  #如果想获取fig更详细结构信息,可以直接 fig.to_dict()
    
    
    • 1
    • 2
    (Scatter({
    'line': {'color': 'blue', 'dash': 'dash', 'width': 2},
    'marker': {'color': 'blue', 'size': 8},
    'mode': 'lines+markers',
    'name': 'acc',
    'x': array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
    18, 19]),
    'y': array([0.1       , 0.55      , 0.7       , 0.775     , 0.82      , 0.85      ,
    0.87142857, 0.8875    , 0.9       , 0.91      , 0.91818182, 0.925     ,
    0.93076923, 0.93571429, 0.94      , 0.94375   , 0.94705882, 0.95      ,
    0.95263158, 0.955     ])
    }),)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    fig.data[0].line.color = "red"
    fig.data[0].line.dash = "solid"
    fig 
    
    • 1
    • 2
    • 3

    图片

    怎么样,plotly是不是一个当之无愧的小透明。😝

    以上这种直接对一个Figure对象的属性的值的修改方法多少显得有些粗暴,不够尊重小透明。

    实际上,plotly的Figure对象提供了 fig.update_layout 和 fig.update_data 这样的方法来让 小透明面对突如其来的修改时候显得更加体面一些。

    import numpy as np 
    import plotly.graph_objs as go
    epoches = np.arange(20)
    accs = 1-0.9/(epoches+1)
    fig = go.Figure(data = go.Scatter(x = epoches, y=accs, mode = "lines+markers",name = "acc",
                        marker = dict(size=8,color="blue"),
                        line= dict(width=2,color="blue",dash="dash")))
    fig.show() 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    图片

    fig.update_traces(patch={"line.color":"red","line.dash":"solid"},selector=dict(name="acc"))
    fig.update_layout({"title":"accuracy via epoch",
              "xaxis.title":"epoch",
              "yaxis.title":"accuracy",
              "font.size":15})
    fig.show() 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    图片

    二,plotly.express绘图原理

    使用 import plotly.graph_objs as go 的go接口来绘制图表实际上已经非常简单了,一般类型的图表三五行代码就可以搞定。

    但我还是想偷懒,能否一行代码就搞定大部分常用图表呢。

    当然可以,plotly.express就是为你准备的。英文单词express 意为 快线,特快列车。就像营养快线的英文,Nutri-express.

    plotly.express的原理非常简单,Figure不是主要由 data(traces)和layout组成嘛。

    data部分传入一个pandas的DataFrame,而layout部分可以用模板template指定嘛,一行代码搞定。

    当然有时候template的一些微观形态可能与用户想要的还不完全一样,将生成的Figure当做小透明直接修改属性即可。

    import plotly.express as px 
    import numpy as np 
    import pandas as pd 
    
    dfdata = pd.DataFrame({"epoch":np.arange(20),"accuracy":1-0.9/(np.arange(20)+1)})
    fig = px.line(data_frame=dfdata,x="epoch",y="accuracy",title="accuracy via epoch")
    fig.show() 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    图片

    可以看到,plotly.express已经帮我们把坐标轴标题什么的都设置好了。

    但是整体看起来还是有些不太美观,多大的事呀,分分钟修改小透明!

    fig.update_traces(patch=dict(mode = "lines+markers",
                        marker = dict(size=8,color="blue"),
                        line= dict(width=2,color="red",dash="solid")),
                      selector=dict(type="scatter")) #用patch指定补丁,用selector指定对那个数据打补丁
    fig.update_layout({"font.size":15})
    fig.show() 
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    图片

    除了精细地修改Figure属性的话,我们想改变Figure样貌的更加快捷的方式是换一个模板(template)

    import plotly 
    print(plotly.io.templates) 
    fig.layout.template = "seaborn" 
    fig 
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    Templates configuration
    -----------------------
        Default template: 'plotly'
        Available templates:
            ['ggplot2', 'seaborn', 'simple_white', 'plotly',
             'plotly_white', 'plotly_dark', 'presentation', 'xgridoff',
             'ygridoff', 'gridon', 'none']
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    图片

    三,常用图表go绘图范例

    plotly支持的图表类型非常丰富,包括各种基础图表,统计图表,金融图表,机器学习图表,地图图表,and more.

    详情参考 https://plotly.com/python/ 中的gallery范例。

    此处只介绍最基础最常用的5种基础图表类型:柱形图、折线图、散点图、热力图、直方图。

    我们先用go接口展示绘图范例,然后作为比较,用px接口再实现一遍。

    1,柱形图

    柱形图适合表现几组数据之间的对比关系,柱形图的数据的数量一般不宜太多。

    import pandas as pd 
    import plotly.graph_objs as go
    
    x = ["f1", "f2", "f3", "f4", "f5"]
    y1 = [5, 20, 36, 10, 75]
    y2 = [10, 25, 8, 60, 20]
    
    traceA = go.Bar(x=x,y=y1,name="模型A")
    traceB = go.Bar(x=x,y=y2,name="模型B")
    layout = go.Layout(title="特征重要性分析",xaxis={"title":"特征"},
        yaxis={"title":"重要性"},barmode="group") #barmode is one of "relative","overlay","group"
    
    fig = go.Figure(data = [traceA,traceB],layout=layout)
    fig.show() 
    
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    图片

    # 互换 x轴 和 y轴含义,orientation设置为horizontal,变成水平条形图
    import pandas as pd 
    import plotly.graph_objs as go
    
    x = ["f1", "f2", "f3", "f4", "f5"]
    y1 = [5, 20, 36, 10, 75]
    y2 = [10, 25, 8, 60, 20]
    
    traceA = go.Bar(x=y1,y=x,name="模型A",orientation='h')
    traceB = go.Bar(x=y2,y=x,name="模型B",orientation='h')
    layout = go.Layout(title="特征重要性分析",xaxis={"title":"重要性"},
        yaxis={"title":"特征"},barmode="group") #barmode is one of "relative","overlay","group"
    
    fig = go.Figure(data = [traceA,traceB],layout=layout)
    fig.show() 
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    图片

    2,折线图

    折线图适合描述两个变量之间的函数关系,例如常用它来描述一个变量随时间的变化趋势。

    import pandas as pd 
    import plotly.graph_objs as go
    
    dates = ['2021-{:0>2d}'.format(s) for s in range(1,13)]
    acc = [70,72,80,65,76,80,60,67,80,90,94,82]
    recall = [65,42,35,25,67,54,34,45,38,46,64,34]
    
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=dates,y=acc,name="准确率",mode = "lines+markers"))
    fig.add_trace(go.Scatter(x=dates,y=recall,name="召回率",mode = "lines+markers"))
    fig.update_layout({"title":"线上模型表现变化趋势","xaxis.title":"月份","yaxis.title":"指标"})
    fig.update_layout({"font":{"size":15}})
    fig.show() 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    图片

    #以上图表中的x轴刻度被自动换成了英文时间,不是很方便识别,使用如下设置直接指定刻度位置和刻度显示内容。
    fig.update_layout(
        xaxis = dict(tickmode='array', 
                     tickvals = x,    
                     ticktext = x,     
                     tickangle = 60
        )
    )
    fig.show() 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    图片

    3,散点图

    散点图适合表现大量样本的多个属性的分布规律。散点图的每个点表示一个样本,每个坐标维度表示一个属性。

    当样本属性维度多于2个时,可以使用点的颜色或大小等方式来表达更多属性维度。

    import pandas as pd 
    import plotly.graph_objs as go
    
    dfboy = pd.DataFrame()
    dfboy['weight'] = [56,67,65,70,57,60,80,85,76,64]
    dfboy['height'] = [162,170,168,172,168,172,180,176,178,170]
    dfboy["BMI"] =  dfboy["weight"]/(dfboy["height"]**2)
    
    dfgirl = pd.DataFrame()
    dfgirl['weight'] = [50,62,60,70,57,45,62,65,70,56]
    dfgirl['height'] = [155,162,165,170,166,158,160,170,172,165]
    dfgirl["gender"] = "female"
    dfgirl["BMI"] = dfgirl["weight"]/(dfgirl["height"]**2)
    
    
    trace1 = go.Scatter(x=dfboy["weight"],y=dfboy["height"],mode="markers",name="male",
                         marker = dict(color="blue",size=3e5*dfboy["BMI"],sizemode='area'))
    
    trace2 = go.Scatter(x=dfgirl["weight"],y=dfgirl["height"],mode="markers",name="female",
                       marker = dict(color="red",size=3e5*dfgirl["BMI"],
                                     sizemode='area'))
    
    layout = go.Layout({"title":"height & weight",
                        "xaxis.title":"weight",
                        "yaxis.title":"height",
                        "legend.title":"gender",
                        "font.size":15})
    
    fig = go.Figure(data=[trace1,trace2],layout=layout)
    fig.show() 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    图片

    4,热力图

    热力图可以直观地展示一个二维矩阵的取值,它将一个矩阵的每个元素取值对应到热力图上的一个像素颜色取值。

    import numpy as np 
    import plotly.graph_objs as go
    
    arr = np.random.normal(loc = 0,scale = 1,size = [10,10])
    trace = go.Heatmap(x=np.arange(10),y=np.arange(10),z=arr,
                       colorscale='Viridis',showscale=True,reversescale = False)
    layout = go.Layout(width=600, height=600)
    fig = go.Figure(data=trace,layout=layout) 
    fig.show() 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    图片

    5,直方图

    直方图适合呈现一组数据的统计分布规律,它计算这组数据落在各个小的分段区间的样本个数并用类似柱状图的方式展示出来。

    import numpy as np 
    import plotly.graph_objs as go 
    scores = np.random.randint(low=0,high=100,size = 1000)
    
    trace = go.Histogram(x=scores,histnorm = 'density',nbinsx=60)
    fig = go.Figure(trace)
    fig.update_layout({"title":"Score Distribution","xaxis.title":"score","yaxis.title":"frequency","template":"seaborn"})
    fig.show() 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    图片

    四,常用图表px绘图范例

    作为对比,下面使用plotly.express接口绘制5种最常用的基础图表:

    柱形图、折线图、散点图、热力图、直方图。

    1,柱形图

    柱形图适合表现几组数据之间的对比关系,柱形图的数据的数量一般不宜太多。

    import pandas as pd 
    import plotly.express as px 
    
    x = ["f1", "f2", "f3", "f4", "f5"]
    y1 = [5, 20, 36, 10, 75]
    y2 = [10, 25, 8, 60, 20]
    df=pd.DataFrame({"特征": x, "模型A": y1, "模型B": y2})
    
    fig = px.bar(data_frame= df, x = "特征",y= ["模型A","模型B"],
           title = "特征重要性分析",barmode = "group") #barmode is one of "relative","overlay","group"
    fig.layout.yaxis.title = "重要性"
    fig.show() 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    # 互换 x轴 和 y轴含义,变成水平条形图
    import pandas as pd 
    import plotly.express as px 
    
    x = ["f1", "f2", "f3", "f4", "f5"]
    y1 = [5, 20, 36, 10, 75]
    y2 = [10, 25, 8, 60, 20]
    df=pd.DataFrame({"特征": x, "模型A": y1, "模型B": y2})
    
    fig = px.bar(data_frame= df, x = ["模型A","模型B"], y= "特征",
           title = "特征重要性分析",barmode = "relative") #barmode is one of "relative","overlay","group"
    fig.layout.xaxis.title = "重要性"
    fig.show() 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    图片

    2,折线图

    折线图适合描述两个变量之间的函数关系,例如常用它来描述一个变量随时间的变化趋势。

    import pandas as pd 
    import plotly.express as px 
    
    x = ['2021-{:0>2d}'.format(s) for s in range(1,13)]
    y1 = [70,72,80,65,76,80,60,67,80,90,94,82]
    y2 = [65,42,35,25,67,54,34,45,38,46,64,34]
    
    dfdata = {"月份":x, "召回率": y1, "准确率": y2}
    
    fig = px.line(data_frame=dfdata, x="月份", y = ["召回率","准确率"], title ="线上模型表现变化趋势")
    fig.layout.yaxis.title = "指标"
    fig.update_layout({"font":{"size":15}})
    fig.show() 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    以上图表中的x轴刻度被自动换成了英文时间,不是很方便识别,使用如下设置直接指定刻度位置和刻度显示内容。

    fig.update_layout(
        xaxis = dict(tickmode='array', 
                     tickvals = x,    
                     ticktext = x,     
                     tickangle = 60
        )
    )
    fig
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    图片

    3,散点图

    散点图适合表现大量样本的多个属性的分布规律。散点图的每个点表示一个样本,每个坐标维度表示一个属性。

    当样本属性维度多于2个时,可以使用点的颜色或大小等方式来表达更多属性维度。

    import pandas as pd 
    import plotly.express as px 
    
    dfboy = pd.DataFrame()
    dfboy['weight'] = [56,67,65,70,57,60,80,85,76,64]
    dfboy['height'] = [162,170,168,172,168,172,180,176,178,170]
    dfboy["gender"] = "male"
    
    dfgirl = pd.DataFrame()
    dfgirl['weight'] = [50,62,60,70,57,45,62,65,70,56]
    dfgirl['height'] = [155,162,165,170,166,158,160,170,172,165]
    dfgirl["gender"] = "female"
    
    dftotal = pd.concat([dfboy,dfgirl])
    dftotal["BMI"] = dftotal["weight"]/(dftotal["height"]**2)
    
    fig = px.scatter(data_frame=dftotal,x="weight",y="height",color="gender",size = "BMI",
                     color_discrete_map = {"male":"blue","female":"red"},
                     title="height & weight")
    fig.update_layout({"font":{"size":15}})
    
    fig
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    图片

    4,热力图

    热力图可以直观地展示一个二维矩阵的取值,它将一个矩阵的每个元素取值对应到热力图上的一个像素颜色取值。

    import numpy as np 
    import plotly.express as px 
    
    arr = np.random.normal(loc = 0,scale = 1,size = [10,10])
    px.imshow(arr,color_continuous_scale="blues")
    
    • 1
    • 2
    • 3
    • 4
    • 5

    图片

    5,直方图

    直方图适合呈现一组数据的统计分布规律,它计算这组数据落在各个小的分段区间的样本个数并用类似柱状图的方式展示出来。

    import numpy as np 
    import plotly.express as px 
    scores = np.random.randint(low=0,high=100,size = 1000)
    fig = px.histogram(x = scores,histnorm = 'density',nbins=60)
    fig.update_layout({"xaxis.title":"score"})
    fig 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    import plotly 
    plotly.io.write_html(fig,"score_distribution.html")
    
    • 1
    • 2

    图片

    五,在机器学习中应用plotly

    本例将使用plotly辅助进行catboost二分类建模的一些可视化分析。

    from IPython.display import display 
    
    import datetime,json
    import numpy as np
    import pandas as pd
    import catboost as cb 
    from sklearn import datasets
    from sklearn.model_selection import train_test_split
    from sklearn.model_selection import StratifiedKFold
    
    from sklearn.metrics import f1_score,roc_auc_score,roc_curve,accuracy_score,precision_recall_curve,auc
    import plotly.graph_objs as go 
    import plotly.express as px 
    from plotly.subplots import make_subplots 
    
    
    def printlog(info):
        nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print("\n"+"=========="*8 + "%s"%nowtime)
        print(info+'...\n\n')
         
    #================================================================================
    # 一,准备数据
    #================================================================================
    printlog("step1: preparing data...")
    
    dfdata = pd.read_csv("../data/titanic/train.csv")
    dftest = pd.read_csv("../data/titanic/test.csv")
    
    
    label_col = "Survived"
    
    # 填充空值特征
    dfnull = pd.DataFrame(dfdata.isnull().sum(axis=0),columns = ["null_cnt"]).query("null_cnt>0")
    
    dfdata.fillna(-9999, inplace=True)
    dftest.fillna(-9999, inplace=True)
    
    
    # 刷选类别特征
    cate_cols = [x for x in dfdata.columns 
                 if dfdata[x].dtype not in [np.float32,np.float64] and x!=label_col]
    for col in cate_cols:
        dfdata[col] = pd.Categorical(dfdata[col]) 
        dftest[col] = pd.Categorical(dftest[col]) 
    
    # 分割数据集
    dftrain,dfvalid = train_test_split(dfdata, train_size=0.75, random_state=42)
    Xtrain,Ytrain = dftrain.drop(label_col,axis = 1),dftrain[label_col]
    Xvalid,Yvalid = dfvalid.drop(label_col,axis = 1),dfvalid[label_col]
    
    
    # 整理成Pool
    pool_train = cb.Pool(data = Xtrain, label = Ytrain, cat_features=cate_cols)
    pool_valid = cb.Pool(data = Xvalid, label = Yvalid, cat_features=cate_cols)
    
    
    #================================================================================
    # 二,设置参数
    #================================================================================
    printlog("step2: setting parameters...")
                                   
    iterations = 1000
    early_stopping_rounds = 200
    
    params = {
        'learning_rate': 0.05,
        'loss_function': cb.metrics.Logloss(),
        'eval_metric': "AUC",
        'depth': 6,
        'min_data_in_leaf': 20,
        'random_seed': 42,
        'logging_level': 'Silent',
        'use_best_model': True,
        'boosting_type':"Ordered",
        'nan_mode': 'Min'
    }
    
    
    #================================================================================
    # 三,训练模型
    #================================================================================
    printlog("step3: training model...")
    
    
    model = cb.CatBoostClassifier(
        iterations = iterations,
        early_stopping_rounds = early_stopping_rounds,
        train_dir='catboost_info/',
        **params
    )
    
    
    model.fit(
        pool_train,
        eval_set=pool_valid,
        plot=True
    )
    
    
    #================================================================================
    # 四,评估模型
    #================================================================================
    printlog("step4: evaluating model ...")
    
    
    y_pred_train = model.predict(Xtrain)
    y_pred_valid = model.predict(Xvalid)
    
    train_score = f1_score(Ytrain,y_pred_train)
    valid_score = f1_score(Yvalid,y_pred_valid)
    
    
    print('train f1_score: {:.5} '.format(train_score))
    print('valid f1_score: {:.5} \n'.format(valid_score))   
    
    
    #feature importance 
    dfimportance = model.get_feature_importance(prettified=True) 
    dfimportance = dfimportance.sort_values(by = "Importances").iloc[-20:]
    fig_importance = px.bar(dfimportance,x="Importances",y="Feature Id",title="Feature Importance")
    
    fig_importance.show() 
    
    
    #score distribution
    y_test_prob = model.predict_proba(dftest.drop(label_col,axis = 1))[:,-1]
    fig_hist = px.histogram(
        x=y_test_prob,color =dftest[label_col],  nbins=50,
        title = "Score Distribution",
        labels=dict(color='True Labels', x='Score')
    )
    fig_hist.show() 
    
    
    #ROC-AUC & PR-AUC
    fpr, tpr, thresholds_roc = roc_curve(dftest[label_col], y_test_prob)
    precision, recall, thresholds_pr = precision_recall_curve(dftest[label_col], y_test_prob)
    
    fig = make_subplots(rows=1, cols=2,horizontal_spacing=0.1,vertical_spacing=0.1,
                        start_cell= 'top-left', # 'bottom-left' 'bottom-left',
                        subplot_titles=[
                        f'ROC Curve (ROC-AUC={auc(fpr, tpr):.4f})',
                        f'PR Curve (PR-AUC={auc(recall, precision):.4f})',] 
                       )
    #ROC-curve
    fig.add_trace(go.Scatter(x=fpr,y=tpr,mode='lines',stackgroup= '1',name="roc_curve"),row=1,col=1)
    fig.add_shape(type='line', line=dict(dash='dash'),x0=0, x1=1, y0=0, y1=1,row=1,col=1)
    fig.update_xaxes(title_text="False Positive Rate", row=1, col=1)
    fig.update_yaxes(title_text="True Positive Rate", row=1, col=1)
    
    #PR-curve
    fig.add_trace(go.Scatter(x=recall,y=precision,mode='lines',stackgroup= '1',name="pr_curve"),row=1,col=2)
    fig.add_shape(type='line', line=dict(dash='dash'),x0=0, x1=1, y0=1, y1=0,row=1,col=2)
    fig.update_xaxes(title_text="Recall", row=1, col=2)
    fig.update_yaxes(title_text="Precision", row=1, col=2)
    
    fig.update_layout({"height":500,"width":1000,"showlegend":False})
    fig.show() 
    
    
    #================================================================================
    # 五,使用模型
    #================================================================================
    printlog("step5: using model ...")
    
    y_pred_test = model.predict(dftest.drop(label_col,axis = 1))
    y_pred_test_prob = model.predict_proba(dftest.drop(label_col,axis = 1))
    
    
    
    #================================================================================
    # 六,保存模型
    #================================================================================
    printlog("step6: saving model ...")
    
    model_dir = 'catboost_model'
    model.save_model(model_dir)
    model_loaded = cb.CatBoostClassifier()
    model.load_model(model_dir)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180

    图片

    图片

    图片

    以上。

    技术交流

    欢迎转载、收藏、有所收获点赞支持一下!数据、代码可以找我获取

    目前开通了技术交流群,群友已超过2000人,添加时最好的备注方式为:来源+兴趣方向,方便找到志同道合的朋友

    • 方式①、发送如下图片至微信,长按识别,后台回复:加群;
    • 方式②、添加vx:dkl88191,备注:来自CSDN

    在这里插入图片描述

  • 相关阅读:
    kafka伪集群部署,使用docker环境拷贝模式
    Unity 利用Cache实现边下边玩
    【前端小点】ElementUI-Dialog标题添加图标
    LeetCode 每日一题——623. 在二叉树中增加一行
    新手初学课,Python入门体验之九九乘法表
    机器学习:一文从入门到读懂PCA(主成分分析)
    细说tcpdump的妙用
    华为OD机试真题2022Q4 A + 2023 B卷(Java&JavaScript)
    threejs视频教程学习(5):水天一色小岛
    『PyQt5-Qt Designer篇』| 13 Qt Designer中如何给工具添加菜单和工具栏?
  • 原文地址:https://blog.csdn.net/qq_34160248/article/details/126555138