• TensorFlow入门(七、检查点)


    保存检查点

    在实际的模型训练中,TensorFlow难免会出现中断的情况,使得到的中间参数丢失,因此需要在模型训练过程中及时将模型保存下来。并将这种在训练中保存模型的操作,称为保存检查点

    通过设置saver的另一个参数max_to_keep,指定生成检查点文件的个数,代码示例如下:

    saver = tf.train.Saver(max_to_keep = 1)

    在保存模型时可以传入迭代次数,如:

    saver.save(sess,saverdir + "linearmodel.cpkt",global_step = epoch)

    载入时同样也要指定迭代次数,如:

    saver.restore(sess2,saverdir + "linearmodel.cpkt-" + str(final_epoch))

    快速获取检查点文件

    快速获取检查点文件有两种方法:

    ①使用get_checkpoint_state函数,传入检查点文件路径作为参数,从而找到检查点文件。该函数返回的是checkpoint 文件CheckpointState proto类型的内容,其有model_checkpoint_path和all_model_checkpoint_paths两个属性。其中model_checkpoint_path保存了最新的检查点文件的文件名,all_model_checkpoint_paths则是未被删除的所有保存下来的检查点文件的文件名。

    1. final_epoch = 18
    2. ckpt = tf.train.get_checkpoint_state("log/")
    3. with tf.Session() as sess2:
    4. sess2.run(init)
    5. if ckpt and ckpt.model_checkpoint_path:
    6. saver.restore(sess2,ckpt.model_checkpoint_path)
    7. print("x=0.5,z=",sess2.run(z,feed_dict = {X:0.5}))

    ②使用latest_checkpoint()函数查找最新保存的检查点文件,该方法是速度最快的。

    1. ckpt = tf.train.latest_checkpoint("log/")
    2. with tf.Session() as sess2:
    3. sess2.run(init)
    4. if ckpt != None:
    5. saver.restore(sess2,ckpt)
    6. print("x=0.5,z=",sess2.run(z,feed_dict = {X:0.5}))

    按照训练时间保存检查点

    使用MonitoredTrainingSession()函数,该函数可以直接实现保存及载入检查点模型的文件,并且可以通过设置save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。

    1. #使用MonitoredTrainingSession()之前,必须定义global_step变量
    2. global_step = tf.train.get_or_create_global_step()
    3. checkpoint_step = tf.assign_add(global_step,1)
    4. #定义检查点保存路径
    5. saverdir = "log/checkpoints"
    6. #启动session
    7. with tf.train.MonitoredTrainingSession(checkpoint_dir=saverdir,save_checkpoint_secs=1) as sess:
    8. print("global_step=",sess.run([global_step]))
    9. #使用死循环,session不结束时就不结束
    10. while not sess.should_stop():
    11. i = sess.run(checkpoint_step)
    12. print("i=",i)

  • 相关阅读:
    【多线程】线程管理
    C++ SLT中的容器学习与函数谓词
    【GESP考级C++】1级样题 闰年统计
    【图像处理】图像基础滤波处理:盒式滤波、均值滤波、高斯滤波、中值滤波、双边滤波
    spicy(一)基本定义
    【Nginx】Windows10 平台下配置Nginx服务实现负载均衡
    DMBOK知识梳理for CDGA/CDGP——第一章数据管理(附常考知识点)
    C#控制台程序中使用log4.net来输出日志
    2024年抖店的市场已经饱和,小白不适合入局了?真实现状如下
    Atcoder ABC159
  • 原文地址:https://blog.csdn.net/Victor_Li_/article/details/133321367