pytorch顯存一直變大的解決方案
在代碼中添加以下兩行可以解決:
torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True
補充:pytorch訓練過程顯存一直增加的問題
之前遇到瞭爆顯存的問題,卡瞭很久,試瞭很多方法,總算解決瞭。
總結下自己試過的幾種方法:
**1. 使用torch.cuda.empty_cache()
在每一個訓練epoch後都添加這一行代碼,可以讓訓練從較低顯存的地方開始,但並不適用爆顯存的問題,隨著epoch的增加,最大顯存占用仍然會提示out of memory 。
2.使用torch.backends.cudnn.enabled = True 和 torch.backends.cudnn.benchmark = True
原理不太清楚,用法和1一樣。但是幾乎沒有效果,直接pass。
3.最重要的:查看自己的forward函數是否存在泄露。
常需要在forward函數裡調用其他子函數,這時候要特別註意:
input盡量不要寫在for循環裡面!!!
子函數裡如果有append()等函數,一定少用,能不用就不用!!!
子函數list一定少用,能不用就不用!!!
總之,子函數一般也不會太復雜,直接寫出來,別各種for,嵌套,變量。!!!
補充:Pytorch顯存不斷增長問題的解決思路
這個問題,我先後遇到過兩次,每次都異常艱辛的解決瞭。
在網上,關於這個問題,你可以找到各種看似不同的解決方案,但是都沒能解決我的問題。所以隻能自己摸索,在摸索的過程中,有瞭一個排查問題點的思路。
下面舉個例子說一下我的思路。
大體思路
其實思路很簡單,就是在代碼的運行階段輸出顯存占用量,觀察在哪一塊存在顯存劇烈增加或者顯存異常變化的情況。
但是在這個過程中要分級確認問題點,也即如果存在三個文件main.py、train.py、model.py。
在此種思路下,應該先在main.py中確定問題點,然後,從main.py中進入到train.py中,再次輸出顯存占用量,確定問題點在哪。
隨後,再從train.py中的問題點,進入到model.py中,再次確認。
如果還有更深層次的調用,可以繼續追溯下去。
具體例子
main.py
def train(model,epochs,data): for e in range(epochs): print("1:{}".format(torch.cuda.memory_allocated(0))) train_epoch(model,data) print("2:{}".format(torch.cuda.memory_allocated(0))) eval(model,data) print("3:{}".format(torch.cuda.memory_allocated(0)))
假設1與2之間顯存增加極為劇烈,說明問題出在train_epoch中,進一步進入到train.py中。
train.py
def train_epoch(model,data): model.train() optim=torch.optimizer() for batch_data in data: print("1:{}".format(torch.cuda.memory_allocated(0))) output=model(batch_data) print("2:{}".format(torch.cuda.memory_allocated(0))) loss=loss(output,data.target) print("3:{}".format(torch.cuda.memory_allocated(0))) optim.zero_grad() print("4:{}".format(torch.cuda.memory_allocated(0))) loss.backward() print("5:{}".format(torch.cuda.memory_allocated(0))) utils.func(model) print("6:{}".format(torch.cuda.memory_allocated(0)))
如果在1,2之間,5,6之間同時出現顯存增加異常的情況。此時需要使用控制變量法,例如我們先讓5,6之間的代碼失效,然後運行,觀察是否仍然存在顯存爆炸。如果沒有,說明問題就出在5,6之間下一級的代碼中。進入到下一級代碼,進行調試:
utils.py
def func(model): print("1:{}".format(torch.cuda.memory_allocated(0))) a=f1(model) print("2:{}".format(torch.cuda.memory_allocated(0))) b=f2(a) print("3:{}".format(torch.cuda.memory_allocated(0))) c=f3(b) print("4:{}".format(torch.cuda.memory_allocated(0))) d=f4(c) print("5:{}".format(torch.cuda.memory_allocated(0)))
此時我們再展示另一種調試思路,先註釋第5行之後的代碼,觀察顯存是否存在先訓爆炸,如果沒有,則註釋掉第7行之後的,直至確定哪一行的代碼出現導致瞭顯存爆炸。假設第9行起作用後,代碼出現顯存爆炸,說明問題出在第九行,顯存爆炸的問題鎖定。
幾種導致顯存爆炸的情況
pytorch的hook機制可能導致,顯存爆炸,hook函數取出某一層的輸入輸出跟權重後,不可進行存儲,修改等操作,這會造成hook不能回收,進而導致取出的輸入輸出權重都可能不被pytorch回收,所以模型的負擔越來也大,最終導致顯存爆炸。
這種情況是我第二次遇到顯存爆炸查出來的,非常讓人匪夷所思。在如下代碼中,p.sub_(torch.mm(k, torch.t(k)) / (alpha + torch.mm(r, k))),導致瞭顯存爆炸,這個問題點就是通過上面的方法確定的。
這個P是一個矩陣,在使用p.sub_的方式更新P的時候,導致瞭顯存爆炸。
將這行代碼修改為p=p-(torch.mm(k, torch.t(k)) / (alpha + torch.mm(r, k))),顯存爆炸的問題解決。
def pro_weight(p, x, w, alpha=1.0, cnn=True, stride=1): if cnn: _, _, H, W = x.shape F, _, HH, WW = w.shape S = stride # stride Ho = int(1 + (H - HH) / S) Wo = int(1 + (W - WW) / S) for i in range(Ho): for j in range(Wo): # N*C*HH*WW, C*HH*WW = N*C*HH*WW, sum -> N*1 r = x[:, :, i * S: i * S + HH, j * S: j * S + WW].contiguous().view(1, -1) # r = r[:, range(r.shape[1] - 1, -1, -1)] k = torch.mm(p, torch.t(r)) p.sub_(torch.mm(k, torch.t(k)) / (alpha + torch.mm(r, k))) w.grad.data = torch.mm(w.grad.data.view(F, -1), torch.t(p.data)).view_as(w) else: r = x k = torch.mm(p, torch.t(r)) p.sub_(torch.mm(k, torch.t(k)) / (alpha + torch.mm(r, k))) w.grad.data = torch.mm(w.grad.data, torch.t(p.data))
以上為個人經驗,希望能給大傢一個參考,也希望大傢多多支持WalkonNet。如有錯誤或未考慮完全的地方,望不吝賜教。
推薦閱讀:
- None Found