pytorch中的廣播語義
pytorch的廣播語義(broadcasting semantics),和numpy的很像,所以可以先看看numpy的文檔:
1、什麼是廣播語義?
官方文檔有這樣一個解釋:
In short, if a PyTorch operation supports broadcast, then its Tensor arguments can be automatically expanded to be of equal sizes (without making copies of the data).
這句話的意思大概是:簡單的說,如果一個pytorch操作支持廣播,那麼它的Tensor參數可以自動的擴展為相同的尺寸(不需要復制數據)。
按照我的理解,應該是指算法計算過程中,不同的Tensor如果size
不同,但是符合一定的規則,那麼可以自動的進行維度擴展,來實現Tensor
的計算。在維度擴展的過程中,並不是真的把維度小的Tensor復制為和維度大的Tensor相同,因為這樣太浪費內存瞭。
2、廣播語義的規則
首先來看標準的情況,兩個Tensor的size相同,則可以直接計算:
x = torch.empty((4, 2, 3)) y = torch.empty((4, 2, 3)) print((x+y).size())
輸出:
torch.Size([4, 2, 3])
但是,如果兩個Tensor
的維度並不相同,pytorch也是可以根據下面的兩個法則進行計算:
- (1)Each tensor has at least one dimension.
- (2)When iterating over the dimension sizes, starting at the trailing dimension, the dimension sizes must either be equal, one of them is 1, or one of them does not exist.
- 每個
Tensor
至少有一個維度。- 迭代標註尺寸時,從後面的標註開始
第一個規則要求每個參與計算的Tensor
至少有一個維度,第二個規則是指在維度迭代時,從最後一個維度開始,可以有三種情況:
- 維度相等
- 其中一個維度是1
- 其中一個維度不存在
3、不符合廣播語義的例子
x = torch.empty((0, )) y = torch.empty((2, 3)) print((x + y).size())
輸出:
RuntimeError: The size of tensor a (0) must match the size of tensor b (3) at non-singleton dimension 1
這裡,不滿足第一個規則“每個參與計算的Tensor
至少有一個維度”。
x = torch.empty(5, 2, 4, 1) y = torch.empty(3, 1, 1) print((x + y).size())
輸出:
RuntimeError: The size of tensor a (2) must match
the size of tensor b (3) at non-singleton dimension 1
這裡,不滿足第二個規則,因為從最後的維度開始迭代的過程中,倒數第三個維度:x是2,y是3。這並不符合第二條規則的三種情況,所以不能使用廣播語義。
4、符合廣播語義的例子
x = torch.empty(5, 3, 4, 1) y = torch.empty(3, 1, 1) print((x + y).size())
輸出:
torch.Size([5, 3, 4, 1])
x是四維的,y是三維的,從最後一個維度開始迭代:
- 最後一維:x是1,y是1,滿足規則二
- 倒數第二維:x是4,y是1,滿足規則二
- 倒數第三維:x是3,y是3,滿足規則一
- 倒數第四維:x是5,y是0,滿足規則一
到此這篇關於pytorch中的廣播語義的文章就介紹到這瞭,更多相關pytorch廣播語義內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!
推薦閱讀:
- Pytorch中expand()的使用(擴展某個維度)
- Broadcast廣播機制在Pytorch Tensor Numpy中的使用詳解
- pytorch教程之Tensor的值及操作使用學習
- PyTorch中Tensor和tensor的區別及說明
- Python深度學習之Pytorch初步使用