pytorch 使用半精度模型部署的操作
背景
pytorch作為深度學習的計算框架正得到越來越多的應用.
我們除瞭在模型訓練階段應用外,最近也把pytorch應用在瞭部署上.
在部署時,為瞭減少計算量,可以考慮使用16位浮點模型,而訓練時涉及到梯度計算,需要使用32位浮點,這種精度的不一致經過測試,模型性能下降有限,可以接受.
但是推斷時計算量可以降低一半,同等計算資源下,並發度可提升近一倍
具體方法
在pytorch中,一般模型定義都繼承torch.nn.Moudle,torch.nn.Module基類的half()方法會把所有參數轉為16位浮點,所以在模型加載後,調用一下該方法即可達到模型切換的目的.接下來隻需要在推斷時把input的tensor切換為16位浮點即可
另外還有一個小的trick,在推理過程中模型輸出的tensor自然會成為16位浮點,如果需要新創建tensor,最好調用已有tensor的new_zeros,new_full等方法而不是torch.zeros和torch.full,前者可以自動繼承已有tensor的類型,這樣就不需要到處增加代碼判斷是使用16位還是32位瞭,隻需要針對input tensor切換.
補充:pytorch 使用amp.autocast半精度加速訓練
準備工作
pytorch 1.6+
如何使用autocast?
根據官方提供的方法,
答案就是autocast + GradScaler。
如何在PyTorch中使用自動混合精度?
答案:autocast + GradScaler。
1.autocast
正如前文所說,需要使用torch.cuda.amp模塊中的autocast 類。使用也是非常簡單的
from torch.cuda.amp import autocast as autocast # 創建model,默認是torch.FloatTensor model = Net().cuda() optimizer = optim.SGD(model.parameters(), ...) for input, target in data: optimizer.zero_grad() # 前向過程(model + loss)開啟 autocast with autocast(): output = model(input) loss = loss_fn(output, target) # 反向傳播在autocast上下文之外 loss.backward() optimizer.step()
2.GradScaler
GradScaler就是梯度scaler模塊,需要在訓練最開始之前實例化一個GradScaler對象。
因此PyTorch中經典的AMP使用方式如下:
from torch.cuda.amp import autocast as autocast # 創建model,默認是torch.FloatTensor model = Net().cuda() optimizer = optim.SGD(model.parameters(), ...) # 在訓練最開始之前實例化一個GradScaler對象 scaler = GradScaler() for epoch in epochs: for input, target in data: optimizer.zero_grad() # 前向過程(model + loss)開啟 autocast with autocast(): output = model(input) loss = loss_fn(output, target) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()
3.nn.DataParallel
單卡訓練的話上面的代碼已經夠瞭,親測在2080ti上能減少至少1/3的顯存,至於速度。。。
要是想多卡跑的話僅僅這樣還不夠,會發現在forward裡面的每個結果都還是float32的,怎麼辦?
class Model(nn.Module): def __init__(self): super(Model, self).__init__() def forward(self, input_data_c1): with autocast(): # code return
隻要把forward裡面的代碼用autocast代碼塊方式運行就好啦!
自動進行autocast的操作
如下操作中tensor會被自動轉化為半精度浮點型的torch.HalfTensor:
1、matmul
2、addbmm
3、addmm
4、addmv
5、addr
6、baddbmm
7、bmm
8、chain_matmul
9、conv1d
10、conv2d
11、conv3d
12、conv_transpose1d
13、conv_transpose2d
14、conv_transpose3d
15、linear
16、matmul
17、mm
18、mv
19、prelu
那麼隻有這些操作才能半精度嗎?不是。其他操作比如rnn也可以進行半精度運行,但是需要自己手動,暫時沒有提供自動的轉換。
推薦閱讀:
- Pytorch 中的optimizer使用說明
- 使用Pytorch實現two-head(多輸出)模型的操作
- pytorch自定義不可導激活函數的操作
- pytorch模型的保存和加載、checkpoint操作
- pytorch教程實現mnist手寫數字識別代碼示例