淺談tensorflow與pytorch的相互轉換
本文以一段代碼為例,簡單介紹一下tensorflow與pytorch的相互轉換(主要是tensorflow轉pytorch),可能介紹的沒有那麼詳細,僅供參考。
由於本人隻熟悉pytorch,而對tensorflow一知半解,而代碼經常遇到tensorflow,而我希望使用pytorch,因此簡單介紹一下tensorflow轉pytorch,可能存在諸多錯誤,希望輕噴~
1.變量預定義
在TensorFlow的世界裡,變量的定義和初始化是分開的。
tensorflow中一般都是在開頭預定義變量,聲明其數據類型、形狀等,在執行的時候再賦具體的值,如下圖所示,而pytorch用到時才會定義,定義和變量初始化是合在一起的。
2.創建變量並初始化
tensorflow中利用tf.Variable創建變量並進行初始化,而pytorch中使用torch.tensor創建變量並進行初始化,如下圖所示。
3.語句執行
在TensorFlow的世界裡,變量的定義和初始化是分開的,所有關於圖變量的賦值和計算都要通過tf.Session的run來進行。
sess.run([G_solver, G_loss_temp, MSE_loss], feed_dict = {X: X_mb, M: M_mb, H: H_mb})
而在pytorch中,並不需要通過run進行,賦值完瞭直接計算即可。
4.tensor
pytorch運算時要創建完的numpy數組轉為tensor,如下:
if use_gpu is True: X_mb = torch.tensor(X_mb, device="cuda") M_mb = torch.tensor(M_mb, device="cuda") H_mb = torch.tensor(H_mb, device="cuda") else: X_mb = torch.tensor(X_mb) M_mb = torch.tensor(M_mb) H_mb = torch.tensor(H_mb)
最後運行完還要將tensor數據類型轉換回numpy數組:
if use_gpu is True: imputed_data=imputed_data.cpu().detach().numpy() else: imputed_data=imputed_data.detach().numpy()
而tensorflow中不需要這種操作。
5.其他函數
在tensorflow中包含諸多函數是pytorch中沒有的,但是都可以在其他庫中找到類似,具體如下表所示。
tensorflow中函數 | pytorch中代替(所在庫) | 參數區別 |
---|---|---|
tf.sqrt | np.sqrt(numpy) | 完全相同 |
tf.random_normal | np.random.normal(numpy) | tf.random_normal(shape = size, stddev = xavier_stddev) np.random.normal(size = size, scale = xavier_stddev) |
tf.concat | torch.cat(torch) | inputs = tf.concat(values = [x, m], axis = 1) inputs = torch.cat(dim=1, tensors=[x, m]) |
tf.nn.relu | F.relu(torch.nn.functional) | 完全相同 |
tf.nn.sigmoid | torch.sigmoid(torch) | 完全相同 |
tf.matmul | torch.matmul(torch) | 完全相同 |
tf.reduce_mean | torch.mean(torch) | 完全相同 |
tf.log | torch.log(torch) | 完全相同 |
tf.zeros | np.zeros | 完全相同 |
tf.train.AdamOptimizer | torch.optim.Adam(torch) | optimizer_D = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D) optimizer_D = torch.optim.Adam(params=theta_D) |
到此這篇關於淺談tensorflow與pytorch的相互轉換的文章就介紹到這瞭,更多相關tensorflow與pytorch的相互轉換內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!
推薦閱讀:
- pytorch 實現L2和L1正則化regularization的操作
- pytorch教程之Tensor的值及操作使用學習
- Pytorch相關知識介紹與應用
- Python深度學習之Pytorch初步使用
- Pytorch實現圖像識別之數字識別(附詳細註釋)