使用tensorflow将TF模型转化成PyTorch模型
获取如下三个文件:
这里假设已经安装过PyTorch了。
开始转化TF2模型位PyTorch模型:
# 安装依赖
pip3 install tensorflow transformers
export BERT_BASE_DIR=~/Downloads/nlp_bert/multi_cased_L-12_H-768_A-12
transformers-cli convert --model_type bert \
--tf_checkpoint $BERT_BASE_DIR/bert_model.ckpt \
--config $BERT_BASE_DIR/bert_config.json \
--pytorch_dump_output $BERT_BASE_DIR/pytorch_model.bin
这里的pytorch_model.bin就是TF2的已经训练好的模型转化过来的PyTorch模型。