在 pytorch 中實現計算圖和自動求導

前言:

今天聊一聊 pytorch 的計算圖和自動求導,我們先從一個簡單例子來看,下面是一個簡單函數建立瞭 yy 和 xx 之間的關系

然後我們結點和邊形式表示上面公式:

上面的式子可以用圖的形式表達,接下來我們用 torch 來計算 x 導數,首先我們創建一個 tensor 並且將其requires_grad設置為True表示隨後反向傳播會對其進行求導。

x = torch.tensor(3.,requires_grad=True)

然後寫出

y = 3*x**2 + 4*x + 2
y.backward()
x.grad

通過調用y.backward()來進行求導,這時就可以通過x.grad來獲得x的導數

x.requires_grad_(False)

可以通過requires_grad_x不參與到自動求導

for epoch in range(3):
  y = 3*x**2 + 4*x + 2
  y.backward()
  print(x.grad)
  x.grad.zero_()

如果這裡沒有調用x.grad_zero_()就是把每次求導數和上一次求導結果進行累加。

鏈式法則

相對於 z 對 x 求偏導時,我們可以將 y 看成常數,這樣 x 導數是 1 那麼

x = torch.tensor([1.,2.,3.],requires_grad=True)
y = x * 2 + 3
z = y **2
out = z.mean()
out.backward()
print(out) #tensor(51.6667, grad_fn=<MeanBackward0>)
print(x.grad) #tensor([ 6.6667, 9.3333, 12.0000])

對於一個簡單的網絡,我們可以手動計算梯度,但是如果擺在你面前的是一個有152 層的網絡怎麼辦?或者該網絡有多個分支。這時你的計算復雜程度可想而知。接下來會帶來更深入自動求導內部機制

到此這篇關於在 pytorch 中實現計算圖和自動求導的文章就介紹到這瞭,更多相關 pytorch 計算圖 內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!

推薦閱讀: