python類參數定義及數據擴展方式unsqueeze/expand
類的參數定義
將conda環境設置為ai,conda activate ai
這個文件的由來:
由於在yolov1的pytorch實現的損失函數中,看到繼承瞭nn.Module,並且其中兩個參數不像c++那裡指定類型,那麼他們的類型是哪裡來的
這裡就是在探索這樣一件事
操作邏輯:
- 先在類中定義瞭構造函數以及一個自定義函數;
- 構造函數定義瞭屬性S、B,自定義函數引入兩個參數,對兩個參數進行調用
- 這裡就說明參數的結構是怎麼樣的,取決於參數被調用瞭什麼東西,比如這裡調用瞭
N = box1.size(0) M = box2.size(0)
說明瞭它是類似一個矩陣的東西,對應的box1的定義就是`torch.rand(10,4)
- 這裡就說明參數的結構是怎麼樣的,取決於參數被調用瞭什麼東西,比如這裡調用瞭
import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable #探究屬性S,B是如何產生的,以及box1、box2是如何產生的、如何調用 class yoloLoss(nn.Module): def __init__(self,S,B): self.S=S self.B=B def compute_iot(self,box1,box2): N = box1.size(0) #調用方式就表示瞭變量是什麼類型,這裡是一個張量,其中每個元素是一個tensor,所以是N*4的張量 M = box2.size(0) print(M,N) yoloLoss1 =yoloLoss(10, 11) yoloLoss1.compute_iot(torch.rand(10,4),torch.rand(11,4))
數據擴展
探究unsqueeze以及expand的使用方法,unsqueeze可以增加一個緯度,但是維度的siz隻是1而已,而expand就可以將數據進行復制,將數據變為n
# 獲得一開始的初始化數值:tensor([[a1,a2,a3]]) nn1=torch.rand(1,3) print(nn1) # unsqueeze是解壓的意思,在第i個維度上進行擴展,將其擴展為tensor([[[a1,a2,a3]]]) nn1=nn1.unsqueeze(0) print("*"*100) print(nn1) #利用expand對數據進行擴展 nn1=nn1.expand(1,3,3) print("*"*100) print(nn1)
到此這篇關於python類參數定義及數據擴展方式unsqueeze/expand的文章就介紹到這瞭,更多相關python unsqueeze/expand內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!
推薦閱讀:
- pytorch下的unsqueeze和squeeze的用法說明
- PyTorch中Tensor和tensor的區別及說明
- Win10操作系統中PyTorch虛擬環境配置+PyCharm配置
- Broadcast廣播機制在Pytorch Tensor Numpy中的使用詳解
- Python深度學習之Pytorch初步使用