• Tensorflow 模型保存、节点修改以及Serving 图优化


    Tensorflow 模型保存、节点修改以及Serving 图优化

    前言 (与正文无关, 可忽略)

    近期打算总结一些 Tensorflow 的基础知识, 方便查阅. 本文的写作动机是考虑到一个小问题: 我们常用 tf.data 系列 API 来生成训练数据, 因此 Train Graph 的输入节点通常是 Iterator 节点 (比如会调用 tf.data.make_one_shot_iterator 以及该对象的 get_next() 方法), 但是在 Serving 的时候, 我在想应该如何处理输入节点, 如何把新增的 tf.placeholder 加入到 Serving 图中.

    一种方法是将 Serving Graph 重新写一遍, 输入节点更新成 tf.placeholder, 然后输入到模型中, 从而生成一个新的 Graph; 但我希望有更简洁的方法, 比如能不能直接将 Iterator 输入节点替换成 tf.placeholder, 这样即便我不知道模型代码是如何写的, 也能构建好 Serving 图. 在该问题的指引下, 对 TF 模型的保存与加载, Graph/MetaGraph 等概念有了稍微深入的了解.

    总览

    本文介绍 Tensorflow 模型部分保存方式, 主要包含 checkpoint 格式、frozen_graph 格式(SavedModel 格式暂略), 通过代码实例了解模型的保存方式, Serving 图的优化以及对 Serving 图中的节点进行修改更新.

    代码地址

    本文代码在 Python 3.5.2 | Tensorflow 1.15.0 环境下测试成功.

    本文所有代码均可以从 https://github.com/axzml/BlogShare/tree/master/Tensorflow/GraphDef 下载.

    广而告之

    可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号;另外可以看看知乎专栏 PoorMemory-机器学习, 以后文章也会发在知乎专栏中.

    checkpoint 格式

    训练代码 & 保存 ckpt

    写了一个简单的训练代码(train.py)如下, 五脏俱全, 其中定义了三个主要函数:

    • data_generator() : 生成 Fake 数据参与模型训练
    • model() : 定义了简单的神经网络
    • train() : 定义训练代码, 调用 tf.train.Saver() 以 checkpoint 的形式保存模型
    # _*_ coding:utf-8 _*_
    ## train.py
    import tensorflow as tf
    import os
    import numpy as np
    from os.path import join, exists
    
    batch_size = 2
    steps = 10
    epochs = 1
    emb_dim = 4
    sample_num = epochs * steps * batch_size
    
    checkpoint_dir = 'checkpoint_dir'
    meta_name = '0'
    saver_dir = join(checkpoint_dir, meta_name)
    
    def data_generator():
    	"""产生 Fake 训练数据"""
        dataset = tf.data.Dataset.from_tensor_slices((np.random.randn(sample_num, emb_dim),\
                            np.random.randn(sample_num)))
        dataset = dataset.repeat(epochs).batch(batch_size)
        iterator = tf.data.make_one_shot_iterator(dataset)
        feature, label = iterator.get_next()
        return feature, label
    
    def model(feature, params=[10, 5, 1]):
    	"""定义模型, 3层DNN"""
        fc1 = tf.layers.dense(feature, params[0], activation=tf.nn.relu, name='fc1')
        fc2 = tf.layers.dense(fc1, params[1], activation=tf.nn.relu, name='fc2')
        fc3 = tf.layers.dense(fc2, params[2], activation=tf.nn.sigmoid, name='fc3')
        out = tf.identity(fc3, name='output')
        return out
    
    def train():
        feature, label = data_generator()
        output = model(feature)
        loss = tf.reduce_mean(tf.square(output - label))
        train_op = tf.train.AdamOptimizer(learning_rate=0.1, name='Adam').minimize(loss)
        saver = tf.train.Saver()
    
        if exists(checkpoint_dir):
            os.system('rm -rf {}'.format(checkpoint_dir))
    
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            try:
                local_step = 0
                save_freq = 2
                while True:
                    local_step += 1
                    _, loss_val = sess.run([train_op, loss])
                    if local_step % save_freq == 0:
                        saver.save(sess, saver_dir)
                    print('loss: {:.4f}'.format(loss_val))
            except tf.errors.OutOfRangeError:
                print("train end!")
    
    
    if __name__ == '__main__':
        train()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61

    运行 python train.py 会在当前目录下生成 checkpoint_dir 目录, 其组成如下:

    checkpoint_dir/
    |-- 0.data-00000-of-00001  ## 记录了网络参数值 
    |-- 0.index  ## 记录了网络参数名
    |-- 0.meta   ## 保存 MetaGraphDef, 该文件以 pb 格式记录了网络结构
    `-- checkpoint  ## 该文件记录了最新的 ckpt
    
    • 1
    • 2
    • 3
    • 4
    • 5

    加载 ckpt & 检查 graph 结构

    checkpoint 格式的模型需要在 Tensorflow 框架下进行加载. 比如编写 eval.py 进行 inference, 代码如下:

    #_*_ coding:utf-8 _*_
    ## eval.py
    import tensorflow as tf
    import os
    from os.path import join, exists
    import numpy as np
    
    emb_dim = 4
    checkpoint_dir = 'checkpoint_dir'
    meta_name = '0'
    saver_dir = join(checkpoint_dir, meta_name)
    meta_file = saver_dir + '.meta'
    model_file = tf.train.latest_checkpoint(checkpoint_dir)
    
    np.random.seed(123)
    test_data = np.random.randn(4, emb_dim) ## 生成测试数据
    
    def eval_graph():
        with tf.Session() as sess:
            saver = tf.train.import_meta_graph(meta_file)
            saver.restore(sess, model_file)
            output = sess.run(['output:0'], feed_dict={
                'IteratorGetNext:0': test_data
            })
            print('eval_graph:\n{}'.format(output))
    
    if __name__ == '__main__':
        eval_graph()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28

    在上面代码中, 注意到输入和输出节点名分别为 output 以及 IteratorGetNext. 对于输出节点, 由于在 train.pymodel() 函数中使用

    out = tf.identity(fc3, name='output')
    
    • 1

    对输出节点重新命名为 output, 因此输出节点的名字非常好确定. 但是输入节点的名字却不太好确定, 原因是训练时采用 tf.data API 来传入数据, 没有显式地对输入节点进行命名. 不过由于保存模型时网络结构都已经存放在 0.meta 文件中了, 因此可以通过解析该文件来查看网络的输入节点, 具体方法如下:

    #_*_ coding:utf-8 _*_
    ## check_graph.py
    import tensorflow as tf
    from tensorflow.python.framework import meta_graph
    from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
    from google.protobuf import text_format
    
    import os
    from os.path import join, exists
    import numpy as np
    
    checkpoint_dir = 'checkpoint_dir'
    meta_name = '0'
    saver_dir = join(checkpoint_dir, meta_name)
    meta_file = saver_dir + '.meta'
    model_file = tf.train.latest_checkpoint(checkpoint_dir)
    
    def read_pb_meta(meta_file):
    	"""读取 pb 格式的 meta 文件"""
        meta_graph_def = meta_graph.read_meta_graph_file(meta_file)
        return meta_graph_def
    
    def read_txt_meta(txt_meta_file):
    	"""读取文本格式的 meta 文件"""
        meta_graph = MetaGraphDef()
        with open(txt_meta_file, 'rb') as f:
            text_format.Merge(f.read(), meta_graph)
        return meta_graph
    
    def read_pb_graph(graph_file):
    	"""读取 pb 格式的 graph_def 文件"""
        try:
            with tf.gfile.GFile(graph_file, 'rb') as pb:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(pb.read())
        except IOError as e:
            raise Exception("Parse '{}' Failed!".format(graph_file))
        return graph_def
    
    
    def check_graph_def(graph_def):
    	"""检查 graph_def 中的各节点"""
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(
                graph_def,
                name=""
            )
            print('===> {}'.format(type(graph)))
            for op in graph.get_operations():
                print(op.name, op.values())  ## 打印网络结构
    
    def check_graph(graph_file):
    	"""检查 pb 格式的 graph_def 文件中的各节点"""
        graph_def = read_pb_graph(graph_file)
        check_graph_def(graph_def)
        
    
    if __name__ == '__main__':
        check_graph_def(read_pb_meta(meta_file).graph_def)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59

    输出结果如下图所示, 可以发现距离网络参数 fc1/kernel 最近的节点是 IteratorGetNext, 因此输入节点的名字基本可以确认是它了.

    节点修改

    现在回到 “前言” 中提到的问题, 如果我希望使用自行创建的 tf.placeholder 节点作为 Graph 的输入节点, 而不是采用 IteratorGetNext, 应该如何实现. 一方面可以重新将 Tensorflow Graph 写一遍, 使用 tf.placeholder 作为输入; 另一方面其实可以考虑将 IteratorGetNet 节点用自定义的节点给替换掉, 这一步参考了博文 如何在建好TF图后修改图. 具体做法如下, 代码在 infer.py 中:

    #_*_ coding:utf-8 _*_
    ## infer.py
    import tensorflow as tf
    from tensorflow.python.framework import meta_graph
    import os
    from os.path import join, exists
    import numpy as np
    
    emb_dim = 4
    checkpoint_dir = 'checkpoint_dir'
    meta_name = '0'
    saver_dir = join(checkpoint_dir, meta_name)
    meta_file = saver_dir + '.meta'
    model_file = tf.train.latest_checkpoint(checkpoint_dir)
    
    np.random.seed(123)
    test_data = np.random.randn(4, emb_dim)
    
    def read_pb_meta(meta_file):
        meta_graph_def = meta_graph.read_meta_graph_file(meta_file)
        return meta_graph_def
    
    def update_node(graph, src_node_name, tar_node):
        """
        @params:
            graph : tensorflow Graph object
            src_node_name : source node name to be modified
            tar_node : target node
        """
        input = graph.get_tensor_by_name('{}:0'.format(src_node_name))
        for op in input.consumers():
            idx_list = []
            for idx, item in enumerate(op.inputs):
                if src_node_name in item.name:
                    idx_list.append(idx)
            for idx in idx_list:
                op._update_input(idx, tar_node)
    
    def modify_graph():
        meta_graph_def = read_pb_meta(meta_file)
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(meta_graph_def.graph_def, name="")
            input_ph = tf.placeholder(tf.float64, [None, emb_dim], name='input')
            update_node(graph, 'IteratorGetNext', input_ph)
    
        with tf.Session(graph=graph) as sess:
            saver = tf.train.import_meta_graph(meta_file)
            saver.restore(sess, model_file)
            output = sess.run(['output:0'], feed_dict={
                'input:0': test_data
            })
            print('modify_graph:\n{}'.format(output))
    
    
    if __name__ == '__main__':
        modify_graph()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56

    该文件定义了函数 update_node 来实现对 graph 中节点的替换, 函数如下:

    def update_node(graph, src_node_name, tar_node):
        """
        @params:
            graph : tensorflow Graph object
            src_node_name : source node name to be modified
            tar_node : target node
        """
        input = graph.get_tensor_by_name('{}:0'.format(src_node_name))
        for op in input.consumers():
            idx_list = []
            for idx, item in enumerate(op.inputs):
                if src_node_name in item.name:
                    idx_list.append(idx)
            for idx in idx_list:
                op._update_input(idx, tar_node)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    其中 src_node_name 表示要被替换掉的节点名字, 比如希望替换 IteratorGetNext. 通过该名字在 graph 中找到对应的节点 input, 然后调用 input.consumers() 找到使用该节点的 op, 再通过更新 op 的输入 (op.inputs) 来实现对节点的替换. 由于替换的方法 op._update_input 需要使用索引 idx, 因此用 idx_list 来记录要替换节点的索引.

    frozen_graph 格式

    前面介绍的 checkpoint 格式将网络结构和参数分开保存, 而 frozen_graph 格式则会将网络参数以 Const 节点的形式写入到 GraphDef, 并保存到统一的 protobuf 文件中, 由于 protobuf 是跨语言、跨平台序列化数据协议, 因此还可以用 C++/Java/Python 等对模型进行加载.

    下面写了个简单的将 ckpt 转换为 frozen_graph 的例子 frozen_graph.py, 代码如下:

    #_*_ coding:utf-8 _*_
    ## frozen_graph.py
    import tensorflow as tf
    from tensorflow.python.framework import meta_graph
    from tensorflow.python.framework import dtypes
    from tensorflow.python.tools import optimize_for_inference_lib
    import os
    from os.path import join, exists
    import numpy as np
    
    emb_dim = 4
    checkpoint_dir = 'checkpoint_dir'
    meta_name = '0'
    saver_dir = join(checkpoint_dir, meta_name)
    meta_file = saver_dir + '.meta'
    model_file = tf.train.latest_checkpoint(checkpoint_dir)
    
    np.random.seed(123)
    test_data = np.random.randn(4, emb_dim)
    
    def read_pb_meta(meta_file):
        meta_graph_def = meta_graph.read_meta_graph_file(meta_file)
        return meta_graph_def
    
    def update_node(graph, src_node_name, tar_node):
        """
        @params:
            graph : tensorflow Graph object
            src_node_name : source node name to be modified
            tar_node : target node
        """
        input = graph.get_tensor_by_name('{}:0'.format(src_node_name))
        for op in input.consumers():
            idx_list = []
            for idx, item in enumerate(op.inputs):
                if src_node_name in item.name:
                    idx_list.append(idx)
            for idx in idx_list:
                op._update_input(idx, tar_node)
    
    def check_graph_def(graph_def):
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(
                graph_def,
                name=""
            )
            print('===> {}'.format(type(graph)))
            for op in graph.get_operations():
                print(op.name, op.values())  ## 打印网络结构
    
    def write_frozen_graph():
        meta_graph_def = read_pb_meta(meta_file)
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(meta_graph_def.graph_def, name="")
            input_ph = tf.placeholder(tf.float64, [None, emb_dim], name='input')
            update_node(graph, 'IteratorGetNext', input_ph)
    
        with tf.Session(graph=graph) as sess:
            saver = tf.train.import_meta_graph(meta_file)
            saver.restore(sess, model_file)
    
            input_node_names = ['input']
            ##placeholder_type_enum = [dtypes.float64.as_datatype_enum]
            placeholder_type_enum = [input_ph.dtype.as_datatype_enum]
            output_node_names = ['output']
            ## 对 graph 进行优化, 把和 inference 无关的节点给删除, 比如 Saver 有关的节点
            graph_def = optimize_for_inference_lib.optimize_for_inference(
                graph.as_graph_def(), input_node_names, output_node_names, placeholder_type_enum
            )
            check_graph_def(graph_def)
            ## 将 ckpt 转换为 frozen_graph, 网络权重和结构写入统一 pb 文件中, 参数以 Const 的形式保存
            frozen_graph = tf.graph_util.convert_variables_to_constants(sess, 
                graph_def, output_node_names)
            out_graph_path = os.path.join('.', "frozen_model.pb")
            with tf.gfile.GFile(out_graph_path, "wb") as f:
                f.write(frozen_graph.SerializeToString())
    
    def read_frozen_graph():
        with tf.Graph().as_default() as graph:
            graph_def = tf.GraphDef()
            with open("frozen_model.pb", 'rb') as f:
                graph_def.ParseFromString(f.read())
                tf.import_graph_def(graph_def, name='')
            
            # print(graph_def)
        
        with tf.Session(graph=graph) as sess:
            output = sess.run(['output:0'], feed_dict={
                'input:0': test_data
            })
            print('frozen_graph:\n{}'.format(output))   
    
    if __name__ == '__main__':
        write_frozen_graph()
        read_frozen_graph()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95

    其中 write_frozen_graph() 中调用 optimize_for_inference_lib.optimize_for_inference 对 Graph 节点进行优化, 将在下一节进行介绍. 此外还调用 tf.graph_util.convert_variables_to_constants 将 ckpt 转换为 frozen_graph, 参数以 Const 的形式保存:

    Serving 图优化

    在上一节生成 frozen_graph 时, 调用了 optimize_for_inference_lib.optimize_for_inference 对 Graph 节点进行优化, 本节简要对其进行说明. 在调用该函数前如果打印从 checkpoint 中加载的 graph 时, 会发现结构中包含很多在训练时需要但在线 Serving 时并不需要的 Op, 如优化算法 Adam, 模型保存 Saver, 梯度 gradients 等等, 如下图:

    optimize_for_inference_lib.optimize_for_inference 函数的一个主要工作就是将 graph 在 Serving 时无用的 Op 给去除.

    该函数定义在 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/optimize_for_inference_lib.py,

    def optimize_for_inference(input_graph_def, input_node_names, output_node_names,
                               placeholder_type_enum, toco_compatible=False):
      ## ..... 显示核心代码
      optimized_graph_def = strip_unused_lib.strip_unused(
          optimized_graph_def, input_node_names, output_node_names,
          placeholder_type_enum)
      optimized_graph_def = graph_util.remove_training_nodes(
          optimized_graph_def, output_node_names)
      ## .... 
      return optimized_graph_def
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    其中 strip_unused_lib.strip_unused 定义如下:

    def strip_unused(input_graph_def, input_node_names, output_node_names,
                     placeholder_type_enum):
      """Removes unused nodes from a GraphDef.
      Args:
        input_graph_def: A graph with nodes we want to prune.
        input_node_names: A list of the nodes we use as inputs.
        output_node_names: A list of the output nodes.
        placeholder_type_enum: The AttrValue enum for the placeholder data type, or
            a list that specifies one value per input node name.
      Returns:
        A `GraphDef` with all unnecessary ops removed.
      Raises:
        ValueError: If any element in `input_node_names` refers to a tensor instead
          of an operation.
        KeyError: If any element in `input_node_names` is not found in the graph.
      """
      for name in input_node_names:
        if ":" in name:
          raise ValueError(f"Name '{name}' appears to refer to a Tensor, not an "
                           "Operation.")
    
      # Here we replace the nodes we're going to override as inputs with
      # placeholders so that any unused nodes that are inputs to them are
      # automatically stripped out by extract_sub_graph().
      not_found = {name for name in input_node_names}
      inputs_replaced_graph_def = graph_pb2.GraphDef()
      for node in input_graph_def.node:
        if node.name in input_node_names:
          not_found.remove(node.name)
          placeholder_node = node_def_pb2.NodeDef()
          placeholder_node.op = "Placeholder"
          placeholder_node.name = node.name
          if isinstance(placeholder_type_enum, list):
            input_node_index = input_node_names.index(node.name)
            placeholder_node.attr["dtype"].CopyFrom(
                attr_value_pb2.AttrValue(type=placeholder_type_enum[
                    input_node_index]))
          else:
            placeholder_node.attr["dtype"].CopyFrom(
                attr_value_pb2.AttrValue(type=placeholder_type_enum))
          if "_output_shapes" in node.attr:
            placeholder_node.attr["_output_shapes"].CopyFrom(node.attr[
                "_output_shapes"])
          if "shape" in node.attr:
            placeholder_node.attr["shape"].CopyFrom(node.attr["shape"])
          inputs_replaced_graph_def.node.extend([placeholder_node])
        else:
          inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
    
      if not_found:
        raise KeyError(f"The following input nodes were not found: {not_found}.")
    
      output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                      output_node_names)
      return output_graph_def
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55

    该代码需要传入 graph_def, 输入节点名字 input_node_names 以及输出节点名字 output_node_names, 前面一大段代码是为了用 Placeholder 替换原本的输入节点, 算是将整个 Graph 重新写了一遍. 之后在 graph_util.extract_sub_graph 函数中, 利用 BFS 算法保留 Serving 时需要的节点, 而将不需要的节点全部给去除:

    def extract_sub_graph(graph_def, dest_nodes):
      """Extract the subgraph that can reach any of the nodes in 'dest_nodes'.
      Args:
        graph_def: A graph_pb2.GraphDef proto.
        dest_nodes: An iterable of strings specifying the destination node names.
      Returns:
        The GraphDef of the sub-graph.
      Raises:
        TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto.
      """
    
     ## ... BFS 遍历 Serving 时用到的节点
    
      nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name)
    
      nodes_to_keep_list = sorted(
          list(nodes_to_keep), key=lambda n: name_to_seq_num[n])
      # Now construct the output GraphDef
      out = graph_pb2.GraphDef()
      for n in nodes_to_keep_list:
        out.node.extend([copy.deepcopy(name_to_node[n])])
      out.library.CopyFrom(graph_def.library)
      out.versions.CopyFrom(graph_def.versions)
      
      return out
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25

    其中 BFS 函数定义如下:

    def _node_name(n):
      if n.startswith("^"):
        return n[1:]
      else:
        return n.split(":")[0]
    
    def _extract_graph_summary(graph_def):
      """Extracts useful information from the graph and returns them."""
      name_to_input_name = {}  # Keyed by the dest node name.
      name_to_node = {}  # Keyed by node name.
    
      # Keeps track of node sequences. It is important to still output the
      # operations in the original order.
      name_to_seq_num = {}  # Keyed by node name.
      seq = 0
      for node in graph_def.node:
        n = _node_name(node.name)
        name_to_node[n] = node
        name_to_input_name[n] = [_node_name(x) for x in node.input]
        ### ....
        name_to_seq_num[n] = seq
        seq += 1
      return name_to_input_name, name_to_node, name_to_seq_num
    
    def _bfs_for_reachable_nodes(target_nodes, name_to_input_name):
      """Breadth first search for reachable nodes from target nodes."""
      nodes_to_keep = set()
      # Breadth first search to find all the nodes that we should keep.
      next_to_visit = list(target_nodes)
      while next_to_visit:
        node = next_to_visit[0]
        del next_to_visit[0]
        if node in nodes_to_keep:
          # Already visited this node.
          continue
        nodes_to_keep.add(node)
        if node in name_to_input_name:
          next_to_visit += name_to_input_name[node]
      return nodes_to_keep
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39

    之所以把这几段代码单独拎出来, 可以在合适的时候拿出来对 graph_def 进行调试, 打印中间结果. 经过 optimize_for_inference_lib.optimize_for_inference 的处理后, graph 更为简洁轻量, 打印其中的 Op 得到:

    可以看到, 训练中会用到的 Adam, Saver 等节点全部被移除了, 整个 graph 变得异常干净整洁.

    总结

    写文章就是, 一鼓作气, 再而衰, 三而竭, 再一鼓作气.
    我要去玩耍了.

  • 相关阅读:
    【实用工具】frp实现内网穿透
    提升ChatGPT答案质量和准确性的方法Prompt engineering
    Linux:安装IDEA开发工具
    【附源码】计算机毕业设计SSM社区便捷管理系统
    Alibaba/IOC-golang 正式开源 ——打造服务于go开发者的IOC框架
    Python编程:使用PIL进行JPEG图像压缩的简易教程
    linux centos出现No space left on device解决方案
    C#开发的OpenRA游戏之属性RenderSprites(8)
    记录一次LiteFlow项目实战
    Object中的方法
  • 原文地址:https://blog.csdn.net/Eric_1993/article/details/126197197