Pytorch實現List Tensor轉Tensor,reshape拼接等操作
持續更新一些常用的Tensor操作,比如List,Numpy,Tensor之間的轉換,Tensor的拼接,維度的變換等操作。
其它Tensor操作如 einsum等見:待更新。
用到兩個函數:
torch.cat
torch.stack
一、List Tensor轉Tensor (torch.cat)
// An highlighted block >>> t1 = torch.FloatTensor([[1,2],[5,6]]) >>> t2 = torch.FloatTensor([[3,4],[7,8]]) >>> l = [] >>> l.append(t1) >>> l.append(t2) >>> ta = torch.cat(l,dim=0) >>> ta = torch.cat(l,dim=0).reshape(2,2,2) >>> tb = torch.cat(l,dim=1).reshape(2,2,2) >>> ta tensor([[[1., 2.], [5., 6.]], [[3., 4.], [7., 8.]]]) >>> tb tensor([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]])
高維tensor
** 如果理解瞭2D to 3DTensor,以此類推,不難理解3D to 4D,看下面代碼即可明白:**
>>> t1 = torch.range(1,8).reshape(2,2,2) >>> t2 = torch.range(11,18).reshape(2,2,2) >>> l = [] >>> l.append(t1) >>> l.append(t2) >>> torch.cat(l,dim=2).reshape(2,2,2,2) tensor([[[[ 1., 2.], [11., 12.]], [[ 3., 4.], [13., 14.]]], [[[ 5., 6.], [15., 16.]], [[ 7., 8.], [17., 18.]]]]) >>> torch.cat(l,dim=1).reshape(2,2,2,2) tensor([[[[ 1., 2.], [ 3., 4.]], [[11., 12.], [13., 14.]]], [[[ 5., 6.], [ 7., 8.]], [[15., 16.], [17., 18.]]]]) >>> torch.cat(l,dim=0).reshape(2,2,2,2) tensor([[[[ 1., 2.], [ 3., 4.]], [[ 5., 6.], [ 7., 8.]]], [[[11., 12.], [13., 14.]], [[15., 16.], [17., 18.]]]])
二、List Tensor轉Tensor (torch.stack)
代碼:
import torch t1 = torch.FloatTensor([[1,2],[5,6]]) t2 = torch.FloatTensor([[3,4],[7,8]]) l = [t1, t2] t3 = torch.stack(l, dim=2) print(t3.shape) print(t3) ## output: ## torch.Size([2, 2, 2]) ## tensor([[[1., 3.], ## [2., 4.]], ## [[5., 7.], ## [6., 8.]]])
以上為個人經驗,希望能給大傢一個參考,也希望大傢多多支持WalkonNet。
推薦閱讀:
- pytorch tensor計算三通道均值方式
- PyTorch中Tensor和tensor的區別及說明
- 聊聊Pytorch torch.cat與torch.stack的區別
- 淺談pytorch中stack和cat的及to_tensor的坑
- Broadcast廣播機制在Pytorch Tensor Numpy中的使用詳解