Python人工智能學習PyTorch實現WGAN示例詳解
1.GAN簡述
在GAN中,有兩個模型,一個是生成模型,用於生成樣本,一個是判別模型,用於判斷樣本是真還是假。但由於在GAN中,使用的JS散度去計算損失值,很容易導致梯度彌散的情況,從而無法進行梯度下降更新參數,於是在WGAN中,引入瞭Wasserstein Distance,使得訓練變得穩定。本文中我們以服從高斯分佈的數據作為樣本。
2.生成器模塊
這裡從2維數據,最終生成2維,主要目的是為瞭可視化比較方便。也就是說,在生成模型中,我們輸入雜亂無章的2維的數據,通過訓練之後,可以生成一個贗品,這個贗品在模仿高斯分佈。
3.判別器模塊
判別器同樣輸入的是2維的數據。比如我們上面的生成器,生成瞭一個2維的贗品,輸入判別器之後,它能夠最終輸出一個sigmoid轉換後的結果,相當於是一個概率,從而判別,這個贗品到底能不能達到以假亂真的程度。
4.數據生成模塊
由於我們使用的是高斯模型,因此,直接生成我們需要的數據即可。我們在這個模塊中,生成8個服從高斯分佈的數據。
5.判別器訓練
由於使用JS散度去計算損失的時候,會很容易出現梯度極小,接近於0的情況,會使得梯度下降無法進行,因此計算損失的時候,使用瞭Wasserstein Distance,去度量兩個分佈之間的差異。因此我們假如瞭梯度懲罰的因子。
其中,梯度懲罰的模塊如下:
6.生成器訓練
這裡的訓練是緊接著判別器訓練的。也就是說,在一個周期裡面,先訓練判別器,再訓練生成器。
7.結果可視化
通過visdom可視化損失值,通過matplotlib可視化分佈的預測結果。
以上就是人工智能學習PyTorch實現WGAN示例詳解的詳細內容,更多關於PyTorch實現WGAN的資料請關註WalkonNet其它相關文章!
推薦閱讀:
- 人工智能學習pyTorch自建數據集及可視化結果實現過程
- 人工智能學習Pytorch梯度下降優化示例詳解
- PyTorch 可視化工具TensorBoard和Visdom
- Pytorch BCELoss和BCEWithLogitsLoss的使用
- Python深度學習pytorch卷積神經網絡LeNet