在 pytorch 中實現計算圖和自動求導
前言:
今天聊一聊 pytorch 的計算圖和自動求導,我們先從一個簡單例子來看,下面是一個簡單函數建立瞭 yy 和 xx 之間的關系
然後我們結點和邊形式表示上面公式:
上面的式子可以用圖的形式表達,接下來我們用 torch 來計算 x 導數,首先我們創建一個 tensor 並且將其requires_grad
設置為True
表示隨後反向傳播會對其進行求導。
x = torch.tensor(3.,requires_grad=True)
然後寫出
y = 3*x**2 + 4*x + 2
y.backward() x.grad
通過調用y.backward()
來進行求導,這時就可以通過x.grad
來獲得x
的導數
x.requires_grad_(False)
可以通過requires_grad_
讓x
不參與到自動求導
for epoch in range(3): y = 3*x**2 + 4*x + 2 y.backward() print(x.grad) x.grad.zero_()
如果這裡沒有調用x.grad_zero_()
就是把每次求導數和上一次求導結果進行累加。
鏈式法則
相對於 z 對 x 求偏導時,我們可以將 y 看成常數,這樣 x 導數是 1 那麼
x = torch.tensor([1.,2.,3.],requires_grad=True)
y = x * 2 + 3 z = y **2
out = z.mean() out.backward()
print(out) #tensor(51.6667, grad_fn=<MeanBackward0>)
print(x.grad) #tensor([ 6.6667, 9.3333, 12.0000])
對於一個簡單的網絡,我們可以手動計算梯度,但是如果擺在你面前的是一個有152 層的網絡怎麼辦?或者該網絡有多個分支。這時你的計算復雜程度可想而知。接下來會帶來更深入自動求導內部機制
到此這篇關於在 pytorch 中實現計算圖和自動求導的文章就介紹到這瞭,更多相關 pytorch 計算圖 內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!
推薦閱讀:
- PyTorch 如何自動計算梯度
- Pytorch中的backward()多個loss函數用法
- pytorch_detach 切斷網絡反傳方式
- PyTorch梯度下降反向傳播
- pytorch 如何打印網絡回傳梯度