pytorch中Schedule與warmup_steps的用法說明
1. lr_scheduler相關
lr_scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)
其中args.warmup_steps可以認為是耐心系數
num_train_optimization_steps為模型參數的總更新次數
一般來說:
num_train_optimization_steps = int(total_train_examples / args.train_batch_size / args.gradient_accumulation_steps)
Schedule用來調節學習率,拿線性變換調整來說,下面代碼中,step是當前迭代次數。
def lr_lambda(self, step): # 線性變換,返回的是某個數值x,然後返回到類LambdaLR中,最終返回old_lr*x if step < self.warmup_steps: # 增大學習率 return float(step) / float(max(1, self.warmup_steps)) # 減小學習率 return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
在實際運行中,lr_scheduler.step()先將lr初始化為0. 在第一次參數更新時,此時step=1,lr由0變為初始值initial_lr;在第二次更新時,step=2,上面代碼中生成某個實數alpha,新的lr=initial_lr *alpha;在第三次更新時,新的lr是在initial_lr基礎上生成,即新的lr=initial_lr *alpha。
其中warmup_steps可以認為是lr調整的耐心系數。
由於有warmup_steps存在,lr先慢慢增加,超過warmup_steps時,lr再慢慢減小。
在實際中,由於訓練剛開始時,訓練數據計算出的grad可能與期望方向相反,所以此時采用較小的lr,隨著迭代次數增加,lr線性增大,增長率為1/warmup_steps;迭代次數等於warmup_steps時,學習率為初始設定的學習率;迭代次數超過warmup_steps時,學習率逐步衰減,衰減率為1/(total-warmup_steps),再進行微調。
2. gradient_accumulation_steps相關
gradient_accumulation_steps通過累計梯度來解決本地顯存不足問題。
假設原來的batch_size=6,樣本總量為24,gradient_accumulation_steps=2
那麼參數更新次數=24/6=4
現在,減小batch_size=6/2=3,參數更新次數不變=24/3/2=4
在梯度反傳時,每gradient_accumulation_steps次進行一次梯度更新,之前照常利用loss.backward()計算梯度。
補充:pytorch學習筆記 -optimizer.step()和scheduler.step()
optimizer.step()和scheduler.step()的區別
optimizer.step()通常用在每個mini-batch之中,而scheduler.step()通常用在epoch裡面,但是不絕對,可以根據具體的需求來做。隻有用瞭optimizer.step(),模型才會更新,而scheduler.step()是對lr進行調整。
通常我們有
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9) scheduler = lr_scheduler.StepLR(optimizer, step_size = 100, gamma = 0.1) model = net.train(model, loss_function, optimizer, scheduler, num_epochs = 100)
在scheduler的step_size表示scheduler.step()每調用step_size次,對應的學習率就會按照策略調整一次。
所以如果scheduler.step()是放在mini-batch裡面,那麼step_size指的是經過這麼多次迭代,學習率改變一次。
以上為個人經驗,希望能給大傢一個參考,也希望大傢多多支持WalkonNet。
推薦閱讀:
- 聊聊pytorch中Optimizer與optimizer.step()的用法
- pytorch 實現L2和L1正則化regularization的操作
- Pytorch中的學習率衰減及其用法詳解
- Pytorch中求模型準確率的兩種方法小結
- 解決pytorch 模型復制的一些問題