pytorch 一行代碼查看網絡參數總量的實現
大傢還是直接看代碼吧~
netG = Generator() print('# generator parameters:', sum(param.numel() for param in netG.parameters())) netD = Discriminator() print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))
補充:PyTorch查看網絡模型的參數量PARAMS和FLOPS等
在PyTorch中,可以使用torchstat這個庫來查看網絡模型的一些信息,包括總的參數量params、MAdd、顯卡內存占用量和FLOPs等。
示例代碼如下:
from torchstat import stat from torchvision.models import resnet50, resnet101, resnet152, resnext101_32x8d model = resnet50() stat(model, (3, 224, 224))
打印信息如下:
以上為個人經驗,希望能給大傢一個參考,也希望大傢多多支持WalkonNet。如有錯誤或未考慮完全的地方,望不吝賜教。
推薦閱讀:
- pytorch如何獲得模型的計算量和參數量
- pytorch查看網絡參數顯存占用量等操作
- pytorch 計算Parameter和FLOP的操作
- Pytorch教程內置模型源碼實現
- Pytorch 統計模型參數量的操作 param.numel()