pytorch_pretrained_bert如何將tensorflow模型轉化為pytorch模型
pytorch_pretrained_bert將tensorflow模型轉化為pytorch模型
BERT倉庫裡的模型是TensorFlow版本的,需要進行相應的轉換才能在pytorch中使用
在Google BERT倉庫裡下載需要的模型,這裡使用的是中文預訓練模型(chinese_L-12_H-768_A_12)
下載chinese_L-12_H-768_A-12.zip後解壓,裡面有5個文件
chinese_L-12_H-768_A-12.zip後解壓,裡面有5個文件
bert_config.json
bert_model.ckpt.data-00000-of-00001
bert_model.ckpt.index
bert_model.ckpt.meta
vocab.txt
使用bert倉庫裡的convert_bert_original_tf_checkpoint_to_pytorch.py將此模型轉化為pytorch版本的,這裡我的文件夾位置為:D:\Work\BISHE\BERT-Dureader\data\chinese_L-12_H-768_A-12,替換為自己的即可
python convert_tf_checkpoint_to_pytorch.py –tf_checkpoint_path D:\Work\BISHE\BERT-Dureader\data\chinese_L-12_H-768_A-12\bert_model.ckpt –bert_config_file D:\Work\BISHE\BERT-Dureader\data\chinese_L-12_H-768_A-12\bert_config.json –pytorch_dump_path D:\Work\BISHE\BERT-Dureader\data\chinese_L-12_H-768_A-12\pytorch_model.bin
註:這裡讓我疑惑的是模型有5個文件,為什麼轉化的時候使用的是bert_model.ckpt,而且這個文件也不存在呀,是我對TensorFlow的模型不太熟悉,查閱資料之後將5個文件的作用說明如下:
$ tree chinese_L-12_H-768_A-12/ chinese_L-12_H-768_A-12/ ├── bert_config.json <- 模型配置文件 ├── bert_model.ckpt.data-00000-of-00001 <- 保存斷點文件列表,可以用來迅速查找最近一次的斷點文件 ├── bert_model.ckpt.index <- 為數據文件提供索引,存儲的核心內容是以tensor name為鍵以BundleEntry為值的表格entries,BundleEntry主要內容是權值的類型、形狀、偏移、校驗和等信息。 ├── bert_model.ckpt.meta <- 是MetaGraphDef序列化的二進制文件,保存瞭網絡結構相關的數據,包括graph_def和saver_def等 └── vocab.txt <- 模型詞匯表文件 0 directories, 5 files
在調用模型時使用chinese_L-12_H-768_A-12\bert_model.ckpt即可。
TensorFlow 讀取ckpt文件中的tensor,將ckpt模型轉為pytorch模型
想用MobileNet V1訓練自己的數據,發現pytorch沒有MobileNet V1的預訓練權重,隻好先下載TensorFlow的預訓練權重,再轉成pytorch模型。
讀取ckpt中的Tensor名稱以及Tensor值
TensorFlow的MobileNet V1預訓練權重文件如下:
解壓完文件後,發現沒有.ckpt文件,文件名隻需‘./my_model/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt’這樣寫就行。
寫一半發現Tensor名稱好難對應起來。希望能給大傢一個參考,也希望大傢多多支持WalkonNet
推薦閱讀:
- 解決tensorflow模型壓縮的問題_踩坑無數,總算搞定
- 基於tensorflow權重文件的解讀
- Pytorch BertModel的使用說明
- 關於Pytorch中模型的保存與遷移問題
- 淺談tensorflow與pytorch的相互轉換