• 多个checkpoint 的参数进行平均


    source_model 路径下 存在 以下几个checkpoint
    model_checkpoint_path: "model.ckpt-457157707"
    all_model_checkpoint_paths: "model.ckpt-456023526" ,all_model_checkpoint_paths: "model.ckpt-456332667" ,all_model_checkpoint_paths: "model.ckpt-456332668",all_model_checkpoint_paths: "model.ckpt-456832684" ,all_model_checkpoint_paths: "model.ckpt-457157707"

    现在将这些ckpt的参数进行平均 合并成一个model.ckpt-457157708 

    1. import tensorflow as tf
    2. import numpy as np
    3. # 获取所有的checkpoint文件
    4. ckpt_files = ["model.ckpt-456023526", "model.ckpt-456332667", "model.ckpt-456332668", "model.ckpt-456832684", "model.ckpt-457157707"]
    5. ckpt_files = [os.path.join("source_model", ckpt_file) for ckpt_file in ckpt_files]
    6. # 用于存储所有模型的参数
    7. all_model_vars = {}
    8. for ckpt_file in ckpt_files:
    9. reader = tf.train.NewCheckpointReader(ckpt_file)
    10. model_vars = reader.get_variable_to_shape_map()
    11. for var in model_vars:
    12. if var not in all_model_vars:
    13. all_model_vars[var] = []
    14. all_model_vars[var].append(reader.get_tensor(var))
    15. # 计算每个参数的平均值
    16. average_vars = {var: np.mean(values, axis=0) for var, values in all_model_vars.items()}
    17. # 创建一个新的checkpoint文件,并将平均后的参数保存到新的.data文件中
    18. with tf.Session() as sess:
    19. for var_name, var_value in average_vars.items():
    20. var = tf.get_variable(var_name, initializer=var_value)
    21. sess.run(var.initializer)
    22. saver = tf.train.Saver()
    23. saver.save(sess, "source_model/model.ckpt-457157708")
  • 相关阅读:
    Kubernetes(k8s)的Volume数据存储配置储存类型ConfigMap和Secret的使用
    高清图片、视频素材免费下载
    linux 配置安装node.js
    Java中的Collection
    SpringMVC简介
    Day18.2:对象创建的内存分析图解
    主成分分析(PCA)介绍
    Python入门(二)
    第一百五十三回 如何实现滑动窗口
    数据库查询优化:主从读写分离及常见问题
  • 原文地址:https://blog.csdn.net/qq_32806793/article/details/132878772