解決tensorflow模型壓縮的問題_踩坑無數,總算搞定
1.安裝bazel,從github上下載linux版的.sh文件,然後安裝
2.從GitHub上下載最新的TensorFlow源碼
3.進入TensorFlow源碼文件夾,輸入命令
bazel build tensorflow/tools/graph_transforms:transform_graph
這裡會遇到各種坑,比如
ERROR: /opt/tf/tensorflow-master/tensorflow/core/kernels/BUILD:3044:1: C++ compilation of rule ‘//tensorflow/core/kernels:matrix_square_root_op’ failed (Exit 4)
gcc: internal compiler error: Killed (program cc1plus)
這個錯誤是cpu負荷太大,需要加行代碼
# 生成swap鏡像文件 sudo dd if=/dev/zero of=/mnt/512Mb.swap bs=1M count=512 # 對該鏡像文件格式化 sudo mkswap /mnt/512Mb.swap # 掛載該鏡像文件 sudo swapon /mnt/512Mb.swap
又或者這個@aws Error downloading
我看csdn有的博主解決方法是去臨時文件夾刪掉文件重新下載,但是我這邊發現沒用,我這邊的解決方法是運行bazel前先輸入一條命令:
sed -i '\@https://github.com/aws/aws-sdk-cpp/archive/1.5.8.tar.gz@aws' tensorflow/workspace.bzl
命令裡的網址就是實際要下載的文件的地址,因為有的地址可能改瞭
到這裡編譯bazel就完成瞭
4.編譯完瞭就可以模型壓縮瞭,也是一行代碼,in_graph為輸入模型路徑,outputs不動,out_graph為輸出模型路徑,transforms就填一個quantize_weights就可以瞭,這個就是把32bit轉成8bit的,也是此方法最有效的一步;我看有的博主還先編譯summary然後打印出輸入輸出結點,之後再輸入一大堆參數,還刪除一些結點啥的,我這邊都試瞭,最終也並沒有更縮減模型大小,所以就這樣就可以瞭。
bazel-bin/tensorflow/tools/graph_transforms/transform_graph --in_graph=../model/ctpn.pb --outputs='output_node_name' --out_graph=../model/quantized_ctpn.pb --transforms='quantize_weights'
最終從68m縮減到17m,75%的縮減比例,實測效果基本沒啥差別,這方法還是很管用的。
補充:模型壓縮一二三之tensorflow查看ckpt模型裡的參數和數值
查看ckpt模型參數和數值
import os from tensorflow.python import pywrap_tensorflow checkpoint_path = os.path.join("<你的模型的目錄>", "./model.ckpt-11000") # Read data from checkpoint file reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) var_to_shape_map = reader.get_variable_to_shape_map() # Print tensor name and values for key in var_to_shape_map: print("tensor_name: ", key) print(reader.get_tensor(key))
註意:
1、”<你的模型目錄>“是指你的meta、ckpt這些模型存儲的路徑。
比如路徑”/models/model.ckpt-11000.meta”這種,那麼”<你的模型目錄>“就是”/models”
2、當目錄下有多個ckpt時,取最新的model名字到ckpt-<最大數字>就可以瞭,後面不用瞭。
以上為個人經驗,希望能給大傢一個參考,也希望大傢多多支持WalkonNet。如有錯誤或未考慮完全的地方,望不吝賜教。
推薦閱讀:
- 基於tensorflow權重文件的解讀
- pytorch_pretrained_bert如何將tensorflow模型轉化為pytorch模型
- 詳解TensorFlow訓練網絡兩種方式
- 淺談tensorflow語義分割api的使用(deeplab訓練cityscapes)
- pytorch 把圖片數據轉化成tensor的操作