• streamlit+ndraw进行可视化训练深度学习模型


    简介

    如果你喜欢web可视化的方式训练深度学习模型,那么streamlit是一个不可错过的选择!

    优点:

    1. 提供丰富的web组件支持
    2. 嵌入python中,简单易用
    3. 轻松构建一个web页面,按钮控制训练过程

    本文使用streamlit进行web可视化渲染,并使用ndraw进行模型可视化,做到了:

    1. 训练过程可视化
    2. 模型输入输出shape一目了然

    构建环境

    首先安装必要的依赖,tensorflow、streamlit和ndraw为必须依赖,其他依赖根据自己的情况进行安装

    pip install streamlit
    pip install tensorflow
    pip install ndraw
    
    • 1
    • 2
    • 3

    其他的环境自行安装,不过多赘述

    然后引入模块:

    import ndraw
    import streamlit as st
    import tensorflow as tf
    import streamlit.components.v1 as components
    
    • 1
    • 2
    • 3
    • 4

    编写代码

    以mnist数据集为例

    1.获取数据

    书写数据加载方法,如果你的数据集没有改动的话,可以使用@st.cache装饰器,其作用是缓存数据,不用每次训练都重新加载数据

    @st.cache(allow_output_mutation=True)
    def get_data(is_onehot = False):
        # 根据自己的训练数据进行加载
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
        x_train = x_train/255.0
        x_test = x_test/255.0
        if is_onehot:
            y_train = tf.one_hot(y_train,10)
            y_test = tf.one_hot(y_test,10)
        return (x_train, y_train), (x_test, y_test)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    2.构建模型

    简单构建一个模型:如果是较为复杂模型,可以使用ndraw进行维度的查看

    def build_model():
        model = tf.keras.models.Sequential([
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10, activation='softmax')
        ])
        return model
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    3.构建逻辑

    使用streamlit构建模型的逻辑:

    1. 首先设置一个web页面的标题
    2. 在左侧设置一个导航栏:开始和结束
    3. 点击开始的时候开始训练
    4. 添加一个模型扩展位置,点击的时候可以查看模型
    if __name__ == '__main__':
        #设置网页标题
        st.title("训练xx模型")
        #点击开始后进行数据加载和训练
        if st.sidebar.button('开始'):
            (x_train, y_train), (x_test, y_test) = get_data(is_onehot=True)
    
            st.text("train size: {} {}".format(x_train.shape, y_train.shape))
            st.text("test size: {} {}".format(x_test.shape, y_test.shape))
    
            model = build_model()
            #点击查看模型后可以可视化模型
            with st.expander("查看模型"):
                components.html(ndraw.render(model,init_x=200, flow=ndraw.VERTICAL), height=1000, scrolling=True)
            model.compile(loss="categorical_crossentropy", optimizer=tf.keras.optimizers.Adam(lr=0.001),metrics=["accuracy"])
            model.fit(x_train, y_train, batch_size=128, validation_data=(x_test, y_test), epochs=10, verbose=1,callbacks=[TrainCallback(x_test, y_test)])
            st.success('训练结束')
    
        if st.sidebar.button('停止'):
            st.stop()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    4.自定义指标可视化

    tf提供了丰富的自定义功能,包括模型自定义,指标自定义,loss自定义、训练过程自定义等等,此处自定义一个训练过程自定义的Callback,主要用于在训练过程中获取相关的loss和acc进行绘图

    class TrainCallback(tf.keras.callbacks.Callback):
        def __init__(self, test_x, test_y):
            super(TrainCallback, self).__init__()
            self.test_x = test_x
            self.test_y = test_y
    
        def on_train_begin(self, logs=None):
            st.header("训练汇总")
            self.summary_line = st.area_chart()
    
            st.subheader("训练进度")
            self.process_text = st.text("0/{}".format(self.params['epochs']))
            self.process_bar = st.progress(0)
    
            st.subheader('loss曲线')
            self.loss_line = st.line_chart()
    
            st.subheader('accuracy曲线')
            self.acc_line = st.line_chart()
    
        def on_epoch_end(self, epoch, logs=None):
            self.loss_line.add_rows({'train_loss': [logs['loss']], 'val_loss': [logs['val_loss']]})
            self.acc_line.add_rows({'train_acc': [logs['accuracy']], 'val_accuracy': [logs['val_accuracy']]})
            self.process_bar.progress(epoch / self.params['epochs'])
            self.process_text.empty()
            self.process_text.text("{}/{}".format(epoch, self.params['epochs']))
    
        def on_batch_end(self, epoch, logs=None):
            if epoch % 10 == 0 or epoch == self.params['epochs']:
                self.summary_line.add_rows({'loss': [logs['loss']], 'accuracy': [logs['accuracy']]})
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    展示

    在这里插入图片描述
    在这里插入图片描述

    总结

    以上就是整个训练过程,不同的模型只需要更改一下加载数据和构建模型的函数即可,其他内容不变或者根据自己的需求添加

    完整外码已放到gitee自取 visualneu

  • 相关阅读:
    诠释韧性增长,知乎Q3财报里的社区优势和商业化价值
    介绍VMware通过电脑本机网卡链接外部网络
    电脑电源灯一闪一闪开不了机怎么办
    以写Hbase表的方式更新Phoenix索引
    多线程篇1:java创建多线程以及线程状态
    Spring Boot 项目部署方案!打包 + Shell 脚本部署详解
    专业知识单选题练习系列(一)
    Python多任务编程
    选择边缘计算网关的五大优势
    09链表-单链表移除元素
  • 原文地址:https://blog.csdn.net/qq_21120275/article/details/127985517