pytorch DataLoader的num_workers參數與設置大小詳解
Q:在給Dataloader設置worker數量(num_worker)時,到底設置多少合適?這個worker到底怎麼工作的?
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
參數詳解:
1、每次dataloader加載數據時:dataloader一次性創建num_worker個worker,(也可以說dataloader一次性創建num_worker個工作進程,worker也是普通的工作進程),並用batch_sampler將指定batch分配給指定worker,worker將它負責的batch加載進RAM。
然後,dataloader從RAM中找本輪迭代要用的batch,如果找到瞭,就使用。如果沒找到,就要num_worker個worker繼續加載batch到內存,直到dataloader在RAM中找到目標batch。一般情況下都是能找到的,因為batch_sampler指定batch時當然優先指定本輪要用的batch。
2、num_worker設置得大,好處是尋batch速度快,因為下一輪迭代的batch很可能在上一輪/上上一輪…迭代時已經加載好瞭。壞處是內存開銷大,也加重瞭CPU負擔(worker加載數據到RAM的進程是CPU復制的嘛)。num_workers的經驗設置值是自己電腦/服務器的CPU核心數,如果CPU很強、RAM也很充足,就可以設置得更大些。
3、如果num_worker設為0,意味著每一輪迭代時,dataloader不再有自主加載數據到RAM這一步驟(因為沒有worker瞭),而是在RAM中找batch,找不到時再加載相應的batch。缺點當然是速度更慢。
設置大小建議:
1、Dataloader的num_worker設置多少才合適,這個問題是很難有一個推薦的值。有以下幾個建議:
2、num_workers=0表示隻有主進程去加載batch數據,這個可能會是一個瓶頸。
3、num_workers = 1表示隻有一個worker進程用來加載batch數據,而主進程是不參與數據加載的。這樣速度也會很慢。
num_workers>0 表示隻有指定數量的worker進程去加載數據,主進程不參與。增加num_works也同時會增加cpu內存的消耗。所以num_workers的值依賴於 batch size和機器性能。
4、一般開始是將num_workers設置為等於計算機上的CPU數量
5、最好的辦法是緩慢增加num_workers,直到訓練速度不再提高,就停止增加num_workers的值。
補充:pytorch中Dataloader()中的num_workers設置問題
如果num_workers的值大於0,要在運行的部分放進__main__()函數裡,才不會有錯:
import numpy as np import torch from torch.autograd import Variable import torch.nn.functional import matplotlib.pyplot as plt import torch.utils.data as Data BATCH_SIZE=5 x=torch.linspace(1,10,10) y=torch.linspace(10,1,10) torch_dataset=Data.TensorDataset(x,y) loader=Data.DataLoader( dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, ) def main(): for epoch in range(3): for step,(batch_x,batch_y) in enumerate(loader): # training.... print('Epoch:',epoch,'| step:',step,'| batch x:',batch_x.numpy(), '| batch y:',batch_y.numpy()) if __name__=="__main__": main() ''' # 下面這樣直接運行會報錯: for epoch in range(3): for step,(batch_x,batch_y) in enumerate(loader): # training.... print('Epoch:',epoch,'| step:',step,'| batch x:',batch_x.numpy(), '| batch y:',batch_y.numpy() '''
以上為個人經驗,希望能給大傢一個參考,也希望大傢多多支持WalkonNet。
推薦閱讀:
- Pytorch數據讀取之Dataset和DataLoader知識總結
- pytorch鎖死在dataloader(訓練時卡死)
- 我對PyTorch dataloader裡的shuffle=True的理解
- pytorch中關於distributedsampler函數的使用
- 解決Pytorch dataloader時報錯每個tensor維度不一樣的問題