保存模型
使用TensorFlow的saver()类先实例化一个saver对象,然后在session中通过saver的save方法将模型保存起来。代码示例如下:
- #初始化所有变量
- init = tf.global_variable_initializer()
-
- #定义saver和保存路径
- saver = tf.train.Saver()
- saverdir = "save_path"
-
- #启动Session
- with tf.Session() as sess:
- sess.run(init)
- #使用saver的save方法保存
- saver.save(sess,saverdir + "file_name")
其中,filename如果不存在,程序会自动创建。
打印模型中的内容
使用inspect_checkpoint包中的print_tensors_in_checkpoint_file方法将模型中的具体内容打印出来。代码示例如下:
- import tensorflow.compat.v1 as tf
- tf.disable_v2_behavior()
- form tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
-
- saverdir = "log/"
- print_tensors_in_checkpoint_file(savedir + "linearmodel.cpkt",None,True)
保存模型的其他方法
使用saver()类保存模型时,可以在函数中放入参数来实现更高级的功能,如指定存储变量名字与变量的对应关系。代码示例如下:
- W = tf.Variable(1.0,name = "weight")
- b = tf.Variable(2.0,name = "bias")
-
- saver = tf.train.Saver({'weight':W,'bias':b})
- with tf.Session() as sess:
- tf.global_variables_initializer().run()
- saver.save(sess,savedir + "linearmodel.cpkt")
- print_tensors_in_checkpoint_file(savedir + "linearmodel.cpkt",None,True)
载入模型
通过调用saver的restore()函数,从指定的路径找到模型文件,并覆盖到相关参数中。代码示例如下:
- #初始化所有变量
- init = tf.global_variable_initializer()
-
- #定义saver和保存路径
- saver = tf.train.Saver()
- saverdir = "save_path"
-
- #启动Session
- with tf.Session() as sess:
- sess.run(init)
- #使用saver的restore方法载入模型
- print("x=0.2,z=",sess.run(z,feed_dict = {X:0.2}))