source_model 路径下存在以下几个 checkpoint:
model_checkpoint_path: "model.ckpt-457157707"
all_model_checkpoint_paths: "model.ckpt-456023526"
all_model_checkpoint_paths: "model.ckpt-456332667"
all_model_checkpoint_paths: "model.ckpt-456332668"
all_model_checkpoint_paths: "model.ckpt-456832684"
all_model_checkpoint_paths: "model.ckpt-457157707"
现在需要将这些 checkpoint 的参数取平均,合并成一个新的 checkpoint:model.ckpt-457157708。
# Average the parameters of several checkpoints under source_model/ and write
# the result as a new checkpoint, source_model/model.ckpt-457157708.
#
# NOTE: this script uses the TensorFlow 1.x graph/session API
# (tf.Session, tf.get_variable, tf.train.NewCheckpointReader).
import os

import numpy as np
import tensorflow as tf

# Checkpoints whose parameters are averaged, oldest first.
ckpt_files = [
    "model.ckpt-456023526",
    "model.ckpt-456332667",
    "model.ckpt-456332668",
    "model.ckpt-456832684",
    "model.ckpt-457157707",
]
ckpt_files = [os.path.join("source_model", ckpt_file) for ckpt_file in ckpt_files]

# Maps variable name -> list of that variable's values, one entry per
# checkpoint that contains it (in the order of ckpt_files).
all_model_vars = {}

for ckpt_file in ckpt_files:
    reader = tf.train.NewCheckpointReader(ckpt_file)
    for var in reader.get_variable_to_shape_map():
        all_model_vars.setdefault(var, []).append(reader.get_tensor(var))

# Maps variable name -> value to store in the merged checkpoint.
average_vars = {}

for var, values in all_model_vars.items():
    newest = np.asarray(values[-1])
    if np.issubdtype(newest.dtype, np.inexact):
        # Element-wise mean over the checkpoints, cast back so the saved
        # variable keeps the dtype it had in the source checkpoints.
        average_vars[var] = np.mean(values, axis=0).astype(newest.dtype)
    else:
        # Integer / bool / string variables (for example a step counter) have
        # no meaningful average, and np.mean would silently turn them into
        # float64; keep the value from the newest checkpoint that has them.
        average_vars[var] = newest

# Build one variable per parameter and load its value through a placeholder.
# Passing the arrays as initializers would embed every weight as a constant in
# the GraphDef, which is limited to 2 GB and would also bloat the .meta file.
tf_vars = [
    tf.get_variable(var_name, shape=var_value.shape, dtype=var_value.dtype)
    for var_name, var_value in average_vars.items()
]
placeholders = [
    tf.placeholder(var_value.dtype, shape=var_value.shape)
    for var_value in average_vars.values()
]
assign_ops = [
    tf.assign(tf_var, placeholder)
    for tf_var, placeholder in zip(tf_vars, placeholders)
]
saver = tf.train.Saver()

with tf.Session() as sess:
    # Assigning a value also initializes the variable, so no separate
    # initializer run is needed.
    for placeholder, assign_op, var_value in zip(
        placeholders, assign_ops, average_vars.values()
    ):
        sess.run(assign_op, feed_dict={placeholder: var_value})

    # Write the merged parameters as a new checkpoint next to the sources.
    saver.save(sess, "source_model/model.ckpt-457157708")