PyTorch Gradient Descent and Backpropagation
Introduction:
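Backpropagation computes the partial derivative of the cost function C with respect to every weight w and bias b in the network. With these partial derivatives, we update each weight and bias by subtracting a constant α (the learning rate) times the partial derivative of the cost function with respect to that parameter: w := w − α·∂C/∂w and b := b − α·∂C/∂b. This is the popular gradient descent algorithm. The partial derivatives (the gradient) point in the direction of steepest ascent of the cost.
We therefore take a small step in the opposite direction, the direction of steepest descent, which moves us toward a local minimum of the cost function.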
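Here is a minimal sketch of one such step, using PyTorch autograd to obtain ∂C/∂w. The toy cost C(w) = (w − 3)², the starting value w = 0 and α = 0.1 are illustrative assumptions and are not part of the original example.

```python
import torch

def cost(w):
    # Toy cost with its minimum at w = 3 (illustrative assumption)
    return (w - 3.0) ** 2

def gradient_step(w0, alpha):
    """Take one gradient-descent step from w0; return (grad, new_w, cost_before, cost_after)."""
    w = torch.tensor(w0, requires_grad=True)
    c = cost(w)
    c.backward()                    # w.grad now holds dC/dw = 2 * (w - 3)
    grad = w.grad.item()
    with torch.no_grad():
        new_w = w - alpha * w.grad  # move against the gradient (steepest descent)
    return grad, new_w.item(), c.item(), cost(new_w).item()

grad, new_w, before, after = gradient_step(0.0, alpha=0.1)
print(grad, new_w, before, after)
```

At w = 0 the gradient is −6, so steepest ascent points toward smaller w. Stepping in the opposite direction moves w up to about 0.6 and lowers the cost from 9 to about 5.76.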
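The task is to fit the data with the quadratic model ŷ = w0·x² + w1·x + w2 (w[0], w[1] and w[2] in the code). In other words, we look for weights that make the squared-error (MSE) loss as small as possible. The code computes the loss of one sample at a time and updates the weights after every sample, which is stochastic gradient descent.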
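The gradient that `l.backward()` writes into `w.grad` can also be derived by hand. For a single sample with the per-sample squared-error loss used in the code:

$$
\hat{y} = w_0 x^2 + w_1 x + w_2, \qquad \ell = (\hat{y} - y)^2
$$

$$
\left(\frac{\partial \ell}{\partial w_0},\ \frac{\partial \ell}{\partial w_1},\ \frac{\partial \ell}{\partial w_2}\right) = 2(\hat{y} - y)\,\left(x^2,\ x,\ 1\right)
$$

Take the initial weights w = (1, 1, 1) and the first sample (x, y) = (1, 2). Then ŷ = 3 and the gradient is 2·(3 − 2)·(1, 1, 1) = (2, 2, 2), which is exactly the first `grad` line the program prints. Each step then applies w := w − 0.01·∂ℓ/∂w.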
The code is as follows:
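```python
import os

# Workaround for the Intel OpenMP "OMP: Error #15" (duplicate libiomp) crash
# seen on some conda/Windows installs; set it before importing torch.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import matplotlib.pyplot as plt
import torch

# Dataset (the points lie on y = 2x)
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# Weights of y_hat = w[0]*x^2 + w[1]*x + w[2], all initialised to 1
w = torch.tensor([1.0, 1.0, 1.0], requires_grad=True)  # gradients are needed for w

# Forward pass: the quadratic model
def forward(x):
    return w[0] * (x ** 2) + w[1] * x + w[2]

# Squared error of a single sample
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

# Training
print('predict (before training)', 4, forward(4).item())
epoch_list = []
w_list = []
loss_list = []
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        l = loss(x, y)
        l.backward()                 # backward pass: fills w.grad
        print('\tgrad:', x, y, w.grad)
        with torch.no_grad():
            w -= 0.01 * w.grad       # gradient-descent step, learning rate 0.01
        w.grad.zero_()               # reset the gradient; backward() would otherwise accumulate it
    # l is the loss of the last sample (x = 3) in this epoch
    print('progress:', epoch, l.item())
    epoch_list.append(epoch)
    w_list.append(w.detach().clone())  # copy, because w is updated in place
    loss_list.append(l.item())
print('predict (after training)', 4, forward(4).item())

# Plot the loss curve
plt.plot(epoch_list, loss_list, 'b')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid()
plt.show()
```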
The output is:

```
predict (before training) 4 21.0
	grad: 1.0 2.0 tensor([2., 2., 2.])
	grad: 2.0 4.0 tensor([22.8800, 11.4400, 5.7200])
	grad: 3.0 6.0 tensor([77.0472, 25.6824, 8.5608])
progress: 0 18.321826934814453
	grad: 1.0 2.0 tensor([-1.1466, -1.1466, -1.1466])
	grad: 2.0 4.0 tensor([-15.5367, -7.7683, -3.8842])
	grad: 3.0 6.0 tensor([-30.4322, -10.1441, -3.3814])
progress: 1 2.858394145965576
	grad: 1.0 2.0 tensor([0.3451, 0.3451, 0.3451])
	grad: 2.0 4.0 tensor([2.4273, 1.2137, 0.6068])
	grad: 3.0 6.0 tensor([19.4499, 6.4833, 2.1611])
progress: 2 1.1675907373428345
	grad: 1.0 2.0 tensor([-0.3224, -0.3224, -0.3224])
	grad: 2.0 4.0 tensor([-5.8458, -2.9229, -1.4614])
	grad: 3.0 6.0 tensor([-3.8829, -1.2943, -0.4314])
progress: 3 0.04653334245085716
	grad: 1.0 2.0 tensor([0.0137, 0.0137, 0.0137])
	grad: 2.0 4.0 tensor([-1.9141, -0.9570, -0.4785])
	grad: 3.0 6.0 tensor([6.8557, 2.2852, 0.7617])
progress: 4 0.14506366848945618
	grad: 1.0 2.0 tensor([-0.1182, -0.1182, -0.1182])
	grad: 2.0 4.0 tensor([-3.6644, -1.8322, -0.9161])
	grad: 3.0 6.0 tensor([1.7455, 0.5818, 0.1939])
progress: 5 0.009403289295732975
	grad: 1.0 2.0 tensor([-0.0333, -0.0333, -0.0333])
	grad: 2.0 4.0 tensor([-2.7739, -1.3869, -0.6935])
	grad: 3.0 6.0 tensor([4.0140, 1.3380, 0.4460])
progress: 6 0.04972923547029495
	grad: 1.0 2.0 tensor([-0.0501, -0.0501, -0.0501])
	grad: 2.0 4.0 tensor([-3.1150, -1.5575, -0.7788])
	grad: 3.0 6.0 tensor([2.8534, 0.9511, 0.3170])
progress: 7 0.025129113346338272
	grad: 1.0 2.0 tensor([-0.0205, -0.0205, -0.0205])
	grad: 2.0 4.0 tensor([-2.8858, -1.4429, -0.7215])
	grad: 3.0 6.0 tensor([3.2924, 1.0975, 0.3658])
progress: 8 0.03345605731010437
	grad: 1.0 2.0 tensor([-0.0134, -0.0134, -0.0134])
	grad: 2.0 4.0 tensor([-2.9247, -1.4623, -0.7312])
	grad: 3.0 6.0 tensor([2.9909, 0.9970, 0.3323])
progress: 9 0.027609655633568764
	grad: 1.0 2.0 tensor([0.0033, 0.0033, 0.0033])
	grad: 2.0 4.0 tensor([-2.8414, -1.4207, -0.7103])
	grad: 3.0 6.0 tensor([3.0377, 1.0126, 0.3375])
progress: 10 0.02848036028444767
	grad: 1.0 2.0 tensor([0.0148, 0.0148, 0.0148])
	grad: 2.0 4.0 tensor([-2.8174, -1.4087, -0.7043])
	grad: 3.0 6.0 tensor([2.9260, 0.9753, 0.3251])
progress: 11 0.02642466314136982
	grad: 1.0 2.0 tensor([0.0280, 0.0280, 0.0280])
	grad: 2.0 4.0 tensor([-2.7682, -1.3841, -0.6920])
	grad: 3.0 6.0 tensor([2.8915, 0.9638, 0.3213])
progress: 12 0.025804826989769936
	grad: 1.0 2.0 tensor([0.0397, 0.0397, 0.0397])
	grad: 2.0 4.0 tensor([-2.7330, -1.3665, -0.6832])
	grad: 3.0 6.0 tensor([2.8243, 0.9414, 0.3138])
progress: 13 0.02462013065814972
	grad: 1.0 2.0 tensor([0.0514, 0.0514, 0.0514])
	grad: 2.0 4.0 tensor([-2.6934, -1.3467, -0.6734])
	grad: 3.0 6.0 tensor([2.7756, 0.9252, 0.3084])
progress: 14 0.023777369409799576
	grad: 1.0 2.0 tensor([0.0624, 0.0624, 0.0624])
	grad: 2.0 4.0 tensor([-2.6580, -1.3290, -0.6645])
	grad: 3.0 6.0 tensor([2.7213, 0.9071, 0.3024])
progress: 15 0.0228563379496336
	grad: 1.0 2.0 tensor([0.0731, 0.0731, 0.0731])
	grad: 2.0 4.0 tensor([-2.6227, -1.3113, -0.6557])
	grad: 3.0 6.0 tensor([2.6725, 0.8908, 0.2969])
progress: 16 0.022044027224183083
	grad: 1.0 2.0 tensor([0.0833, 0.0833, 0.0833])
	grad: 2.0 4.0 tensor([-2.5893, -1.2946, -0.6473])
	grad: 3.0 6.0 tensor([2.6240, 0.8747, 0.2916])
progress: 17 0.02125072106719017
	grad: 1.0 2.0 tensor([0.0931, 0.0931, 0.0931])
	grad: 2.0 4.0 tensor([-2.5568, -1.2784, -0.6392])
	grad: 3.0 6.0 tensor([2.5780, 0.8593, 0.2864])
progress: 18 0.020513182505965233
	grad: 1.0 2.0 tensor([0.1025, 0.1025, 0.1025])
	grad: 2.0 4.0 tensor([-2.5258, -1.2629, -0.6314])
	grad: 3.0 6.0 tensor([2.5335, 0.8445, 0.2815])
progress: 19 0.019810274243354797
	grad: 1.0 2.0 tensor([0.1116, 0.1116, 0.1116])
	grad: 2.0 4.0 tensor([-2.4958, -1.2479, -0.6239])
	grad: 3.0 6.0 tensor([2.4908, 0.8303, 0.2768])
progress: 20 0.019148115068674088
	grad: 1.0 2.0 tensor([0.1203, 0.1203, 0.1203])
	grad: 2.0 4.0 tensor([-2.4669, -1.2335, -0.6167])
	grad: 3.0 6.0 tensor([2.4496, 0.8165, 0.2722])
progress: 21 0.018520694226026535
	grad: 1.0 2.0 tensor([0.1286, 0.1286, 0.1286])
	grad: 2.0 4.0 tensor([-2.4392, -1.2196, -0.6098])
	grad: 3.0 6.0 tensor([2.4101, 0.8034, 0.2678])
progress: 22 0.017927465960383415
	grad: 1.0 2.0 tensor([0.1367, 0.1367, 0.1367])
	grad: 2.0 4.0 tensor([-2.4124, -1.2062, -0.6031])
	grad: 3.0 6.0 tensor([2.3720, 0.7907, 0.2636])
progress: 23 0.01736525259912014
	grad: 1.0 2.0 tensor([0.1444, 0.1444, 0.1444])
	grad: 2.0 4.0 tensor([-2.3867, -1.1933, -0.5967])
	grad: 3.0 6.0 tensor([2.3354, 0.7785, 0.2595])
progress: 24 0.016833148896694183
	grad: 1.0 2.0 tensor([0.1518, 0.1518, 0.1518])
	grad: 2.0 4.0 tensor([-2.3619, -1.1810, -0.5905])
	grad: 3.0 6.0 tensor([2.3001, 0.7667, 0.2556])
progress: 25 0.01632905937731266
	grad: 1.0 2.0 tensor([0.1589, 0.1589, 0.1589])
	grad: 2.0 4.0 tensor([-2.3380, -1.1690, -0.5845])
	grad: 3.0 6.0 tensor([2.2662, 0.7554, 0.2518])
progress: 26 0.01585075818002224
	grad: 1.0 2.0 tensor([0.1657, 0.1657, 0.1657])
	grad: 2.0 4.0 tensor([-2.3151, -1.1575, -0.5788])
	grad: 3.0 6.0 tensor([2.2336, 0.7445, 0.2482])
progress: 27 0.015397666022181511
	grad: 1.0 2.0 tensor([0.1723, 0.1723, 0.1723])
	grad: 2.0 4.0 tensor([-2.2929, -1.1465, -0.5732])
	grad: 3.0 6.0 tensor([2.2022, 0.7341, 0.2447])
progress: 28 0.014967591501772404
	grad: 1.0 2.0 tensor([0.1786, 0.1786, 0.1786])
	grad: 2.0 4.0 tensor([-2.2716, -1.1358, -0.5679])
	grad: 3.0 6.0 tensor([2.1719, 0.7240, 0.2413])
progress: 29 0.014559715054929256
	grad: 1.0 2.0 tensor([0.1846, 0.1846, 0.1846])
	grad: 2.0 4.0 tensor([-2.2511, -1.1255, -0.5628])
	grad: 3.0 6.0 tensor([2.1429, 0.7143, 0.2381])
progress: 30 0.014172340743243694
	grad: 1.0 2.0 tensor([0.1904, 0.1904, 0.1904])
	grad: 2.0 4.0 tensor([-2.2313, -1.1157, -0.5578])
	grad: 3.0 6.0 tensor([2.1149, 0.7050, 0.2350])
progress: 31 0.013804304413497448
	grad: 1.0 2.0 tensor([0.1960, 0.1960, 0.1960])
	grad: 2.0 4.0 tensor([-2.2123, -1.1061, -0.5531])
	grad: 3.0 6.0 tensor([2.0879, 0.6960, 0.2320])
progress: 32 0.013455045409500599
	grad: 1.0 2.0 tensor([0.2014, 0.2014, 0.2014])
	grad: 2.0 4.0 tensor([-2.1939, -1.0970, -0.5485])
	grad: 3.0 6.0 tensor([2.0620, 0.6873, 0.2291])
progress: 33 0.013122711330652237
	grad: 1.0 2.0 tensor([0.2065, 0.2065, 0.2065])
	grad: 2.0 4.0 tensor([-2.1763, -1.0881, -0.5441])
	grad: 3.0 6.0 tensor([2.0370, 0.6790, 0.2263])
progress: 34 0.01280694268643856
	grad: 1.0 2.0 tensor([0.2114, 0.2114, 0.2114])
	grad: 2.0 4.0 tensor([-2.1592, -1.0796, -0.5398])
	grad: 3.0 6.0 tensor([2.0130, 0.6710, 0.2237])
progress: 35 0.012506747618317604
	grad: 1.0 2.0 tensor([0.2162, 0.2162, 0.2162])
	grad: 2.0 4.0 tensor([-2.1428, -1.0714, -0.5357])
	grad: 3.0 6.0 tensor([1.9899, 0.6633, 0.2211])
progress: 36 0.012220758944749832
	grad: 1.0 2.0 tensor([0.2207, 0.2207, 0.2207])
	grad: 2.0 4.0 tensor([-2.1270, -1.0635, -0.5317])
	grad: 3.0 6.0 tensor([1.9676, 0.6559, 0.2186])
progress: 37 0.01194891706109047
	grad: 1.0 2.0 tensor([0.2251, 0.2251, 0.2251])
	grad: 2.0 4.0 tensor([-2.1118, -1.0559, -0.5279])
	grad: 3.0 6.0 tensor([1.9462, 0.6487, 0.2162])
progress: 38 0.011689926497638226
	grad: 1.0 2.0 tensor([0.2292, 0.2292, 0.2292])
	grad: 2.0 4.0 tensor([-2.0971, -1.0485, -0.5243])
	grad: 3.0 6.0 tensor([1.9255, 0.6418, 0.2139])
progress: 39 0.01144315768033266
	grad: 1.0 2.0 tensor([0.2333, 0.2333, 0.2333])
	grad: 2.0 4.0 tensor([-2.0829, -1.0414, -0.5207])
	grad: 3.0 6.0 tensor([1.9057, 0.6352, 0.2117])
progress: 40 0.011208509095013142
	grad: 1.0 2.0 tensor([0.2371, 0.2371, 0.2371])
	grad: 2.0 4.0 tensor([-2.0693, -1.0346, -0.5173])
	grad: 3.0 6.0 tensor([1.8865, 0.6288, 0.2096])
progress: 41 0.0109840864315629
	grad: 1.0 2.0 tensor([0.2408, 0.2408, 0.2408])
	grad: 2.0 4.0 tensor([-2.0561, -1.0280, -0.5140])
	grad: 3.0 6.0 tensor([1.8681, 0.6227, 0.2076])
progress: 42 0.010770938359200954
	grad: 1.0 2.0 tensor([0.2444, 0.2444, 0.2444])
	grad: 2.0 4.0 tensor([-2.0434, -1.0217, -0.5108])
	grad: 3.0 6.0 tensor([1.8503, 0.6168, 0.2056])
progress: 43 0.010566935874521732
	grad: 1.0 2.0 tensor([0.2478, 0.2478, 0.2478])
	grad: 2.0 4.0 tensor([-2.0312, -1.0156, -0.5078])
	grad: 3.0 6.0 tensor([1.8332, 0.6111, 0.2037])
progress: 44 0.010372749529778957
	grad: 1.0 2.0 tensor([0.2510, 0.2510, 0.2510])
	grad: 2.0 4.0 tensor([-2.0194, -1.0097, -0.5048])
	grad: 3.0 6.0 tensor([1.8168, 0.6056, 0.2019])
progress: 45 0.010187389329075813
	grad: 1.0 2.0 tensor([0.2542, 0.2542, 0.2542])
	grad: 2.0 4.0 tensor([-2.0080, -1.0040, -0.5020])
	grad: 3.0 6.0 tensor([1.8009, 0.6003, 0.2001])
progress: 46 0.010010283440351486
	grad: 1.0 2.0 tensor([0.2572, 0.2572, 0.2572])
	grad: 2.0 4.0 tensor([-1.9970, -0.9985, -0.4992])
	grad: 3.0 6.0 tensor([1.7856, 0.5952, 0.1984])
progress: 47 0.00984097272157669
	grad: 1.0 2.0 tensor([0.2600, 0.2600, 0.2600])
	grad: 2.0 4.0 tensor([-1.9864, -0.9932, -0.4966])
	grad: 3.0 6.0 tensor([1.7709, 0.5903, 0.1968])
progress: 48 0.009679674170911312
	grad: 1.0 2.0 tensor([0.2628, 0.2628, 0.2628])
	grad: 2.0 4.0 tensor([-1.9762, -0.9881, -0.4940])
	grad: 3.0 6.0 tensor([1.7568, 0.5856, 0.1952])
progress: 49 0.009525291621685028
	grad: 1.0 2.0 tensor([0.2655, 0.2655, 0.2655])
	grad: 2.0 4.0 tensor([-1.9663, -0.9832, -0.4916])
	grad: 3.0 6.0 tensor([1.7431, 0.5810, 0.1937])
progress: 50 0.00937769003212452
	grad: 1.0 2.0 tensor([0.2680, 0.2680, 0.2680])
	grad: 2.0 4.0 tensor([-1.9568, -0.9784, -0.4892])
	grad: 3.0 6.0 tensor([1.7299, 0.5766, 0.1922])
progress: 51 0.009236648678779602
	grad: 1.0 2.0 tensor([0.2704, 0.2704, 0.2704])
	grad: 2.0 4.0 tensor([-1.9476, -0.9738, -0.4869])
	grad: 3.0 6.0 tensor([1.7172, 0.5724, 0.1908])
progress: 52 0.00910158734768629
	grad: 1.0 2.0 tensor([0.2728, 0.2728, 0.2728])
	grad: 2.0 4.0 tensor([-1.9387, -0.9694, -0.4847])
	grad: 3.0 6.0 tensor([1.7050, 0.5683, 0.1894])
progress: 53 0.00897257961332798
	grad: 1.0 2.0 tensor([0.2750, 0.2750, 0.2750])
	grad: 2.0 4.0 tensor([-1.9301, -0.9651, -0.4825])
	grad: 3.0 6.0 tensor([1.6932, 0.5644, 0.1881])
progress: 54 0.008848887868225574
	grad: 1.0 2.0 tensor([0.2771, 0.2771, 0.2771])
	grad: 2.0 4.0 tensor([-1.9219, -0.9609, -0.4805])
	grad: 3.0 6.0 tensor([1.6819, 0.5606, 0.1869])
progress: 55 0.008730598725378513
	grad: 1.0 2.0 tensor([0.2792, 0.2792, 0.2792])
	grad: 2.0 4.0 tensor([-1.9139, -0.9569, -0.4785])
	grad: 3.0 6.0 tensor([1.6709, 0.5570, 0.1857])
progress: 56 0.00861735362559557
	grad: 1.0 2.0 tensor([0.2811, 0.2811, 0.2811])
	grad: 2.0 4.0 tensor([-1.9062, -0.9531, -0.4765])
	grad: 3.0 6.0 tensor([1.6604, 0.5535, 0.1845])
progress: 57 0.008508718572556973
	grad: 1.0 2.0 tensor([0.2830, 0.2830, 0.2830])
	grad: 2.0 4.0 tensor([-1.8987, -0.9493, -0.4747])
	grad: 3.0 6.0 tensor([1.6502, 0.5501, 0.1834])
progress: 58 0.008404706604778767
	grad: 1.0 2.0 tensor([0.2848, 0.2848, 0.2848])
	grad: 2.0 4.0 tensor([-1.8915, -0.9457, -0.4729])
	grad: 3.0 6.0 tensor([1.6404, 0.5468, 0.1823])
progress: 59 0.008305158466100693
	grad: 1.0 2.0 tensor([0.2865, 0.2865, 0.2865])
	grad: 2.0 4.0 tensor([-1.8845, -0.9423, -0.4711])
	grad: 3.0 6.0 tensor([1.6309, 0.5436, 0.1812])
progress: 60 0.00820931326597929
	grad: 1.0 2.0 tensor([0.2882, 0.2882, 0.2882])
	grad: 2.0 4.0 tensor([-1.8778, -0.9389, -0.4694])
	grad: 3.0 6.0 tensor([1.6218, 0.5406, 0.1802])
progress: 61 0.008117804303765297
	grad: 1.0 2.0 tensor([0.2898, 0.2898, 0.2898])
	grad: 2.0 4.0 tensor([-1.8713, -0.9356, -0.4678])
	grad: 3.0 6.0 tensor([1.6130, 0.5377, 0.1792])
progress: 62 0.008029798977077007
	grad: 1.0 2.0 tensor([0.2913, 0.2913, 0.2913])
	grad: 2.0 4.0 tensor([-1.8650, -0.9325, -0.4662])
	grad: 3.0 6.0 tensor([1.6045, 0.5348, 0.1783])
progress: 63 0.007945418357849121
	grad: 1.0 2.0 tensor([0.2927, 0.2927, 0.2927])
	grad: 2.0 4.0 tensor([-1.8589, -0.9294, -0.4647])
	grad: 3.0 6.0 tensor([1.5962, 0.5321, 0.1774])
progress: 64 0.007864190265536308
	grad: 1.0 2.0 tensor([0.2941, 0.2941, 0.2941])
	grad: 2.0 4.0 tensor([-1.8530, -0.9265, -0.4632])
	grad: 3.0 6.0 tensor([1.5884, 0.5295, 0.1765])
progress: 65 0.007786744274199009
	grad: 1.0 2.0 tensor([0.2954, 0.2954, 0.2954])
	grad: 2.0 4.0 tensor([-1.8473, -0.9236, -0.4618])
	grad: 3.0 6.0 tensor([1.5807, 0.5269, 0.1756])
progress: 66 0.007711691781878471
	grad: 1.0 2.0 tensor([0.2967, 0.2967, 0.2967])
	grad: 2.0 4.0 tensor([-1.8417, -0.9209, -0.4604])
	grad: 3.0 6.0 tensor([1.5733, 0.5244, 0.1748])
progress: 67 0.007640169933438301
	grad: 1.0 2.0 tensor([0.2979, 0.2979, 0.2979])
	grad: 2.0 4.0 tensor([-1.8364, -0.9182, -0.4591])
	grad: 3.0 6.0 tensor([1.5662, 0.5221, 0.1740])
progress: 68 0.007570972666144371
	grad: 1.0 2.0 tensor([0.2991, 0.2991, 0.2991])
	grad: 2.0 4.0 tensor([-1.8312, -0.9156, -0.4578])
	grad: 3.0 6.0 tensor([1.5593, 0.5198, 0.1733])
progress: 69 0.007504733745008707
	grad: 1.0 2.0 tensor([0.3002, 0.3002, 0.3002])
	grad: 2.0 4.0 tensor([-1.8262, -0.9131, -0.4566])
	grad: 3.0 6.0 tensor([1.5527, 0.5176, 0.1725])
progress: 70 0.007440924644470215
	grad: 1.0 2.0 tensor([0.3012, 0.3012, 0.3012])
	grad: 2.0 4.0 tensor([-1.8214, -0.9107, -0.4553])
	grad: 3.0 6.0 tensor([1.5463, 0.5154, 0.1718])
progress: 71 0.007379599846899509
	grad: 1.0 2.0 tensor([0.3022, 0.3022, 0.3022])
	grad: 2.0 4.0 tensor([-1.8167, -0.9083, -0.4542])
	grad: 3.0 6.0 tensor([1.5401, 0.5134, 0.1711])
progress: 72 0.007320486940443516
	grad: 1.0 2.0 tensor([0.3032, 0.3032, 0.3032])
	grad: 2.0 4.0 tensor([-1.8121, -0.9060, -0.4530])
	grad: 3.0 6.0 tensor([1.5341, 0.5114, 0.1705])
progress: 73 0.007263725157827139
	grad: 1.0 2.0 tensor([0.3041, 0.3041, 0.3041])
	grad: 2.0 4.0 tensor([-1.8077, -0.9038, -0.4519])
	grad: 3.0 6.0 tensor([1.5283, 0.5094, 0.1698])
progress: 74 0.007209045812487602
	grad: 1.0 2.0 tensor([0.3050, 0.3050, 0.3050])
	grad: 2.0 4.0 tensor([-1.8034, -0.9017, -0.4508])
	grad: 3.0 6.0 tensor([1.5227, 0.5076, 0.1692])
progress: 75 0.007156429346650839
	grad: 1.0 2.0 tensor([0.3058, 0.3058, 0.3058])
	grad: 2.0 4.0 tensor([-1.7992, -0.8996, -0.4498])
	grad: 3.0 6.0 tensor([1.5173, 0.5058, 0.1686])
progress: 76 0.007105532102286816
	grad: 1.0 2.0 tensor([0.3066, 0.3066, 0.3066])
	grad: 2.0 4.0 tensor([-1.7952, -0.8976, -0.4488])
	grad: 3.0 6.0 tensor([1.5121, 0.5040, 0.1680])
progress: 77 0.00705681974068284
	grad: 1.0 2.0 tensor([0.3073, 0.3073, 0.3073])
	grad: 2.0 4.0 tensor([-1.7913, -0.8956, -0.4478])
	grad: 3.0 6.0 tensor([1.5070, 0.5023, 0.1674])
progress: 78 0.007009552326053381
	grad: 1.0 2.0 tensor([0.3081, 0.3081, 0.3081])
	grad: 2.0 4.0 tensor([-1.7875, -0.8937, -0.4469])
	grad: 3.0 6.0 tensor([1.5021, 0.5007, 0.1669])
progress: 79 0.006964194122701883
	grad: 1.0 2.0 tensor([0.3087, 0.3087, 0.3087])
	grad: 2.0 4.0 tensor([-1.7838, -0.8919, -0.4459])
	grad: 3.0 6.0 tensor([1.4974, 0.4991, 0.1664])
progress: 80 0.006920332089066505
	grad: 1.0 2.0 tensor([0.3094, 0.3094, 0.3094])
	grad: 2.0 4.0 tensor([-1.7802, -0.8901, -0.4450])
	grad: 3.0 6.0 tensor([1.4928, 0.4976, 0.1659])
progress: 81 0.006878111511468887
	grad: 1.0 2.0 tensor([0.3100, 0.3100, 0.3100])
	grad: 2.0 4.0 tensor([-1.7767, -0.8883, -0.4442])
	grad: 3.0 6.0 tensor([1.4884, 0.4961, 0.1654])
progress: 82 0.006837360095232725
	grad: 1.0 2.0 tensor([0.3106, 0.3106, 0.3106])
	grad: 2.0 4.0 tensor([-1.7733, -0.8867, -0.4433])
	grad: 3.0 6.0 tensor([1.4841, 0.4947, 0.1649])
progress: 83 0.006797831039875746
	grad: 1.0 2.0 tensor([0.3111, 0.3111, 0.3111])
	grad: 2.0 4.0 tensor([-1.7700, -0.8850, -0.4425])
	grad: 3.0 6.0 tensor([1.4800, 0.4933, 0.1644])
progress: 84 0.006760062649846077
	grad: 1.0 2.0 tensor([0.3117, 0.3117, 0.3117])
	grad: 2.0 4.0 tensor([-1.7668, -0.8834, -0.4417])
	grad: 3.0 6.0 tensor([1.4759, 0.4920, 0.1640])
progress: 85 0.006723103579133749
	grad: 1.0 2.0 tensor([0.3122, 0.3122, 0.3122])
	grad: 2.0 4.0 tensor([-1.7637, -0.8818, -0.4409])
	grad: 3.0 6.0 tensor([1.4720, 0.4907, 0.1636])
progress: 86 0.00668772729113698
	grad: 1.0 2.0 tensor([0.3127, 0.3127, 0.3127])
	grad: 2.0 4.0 tensor([-1.7607, -0.8803, -0.4402])
	grad: 3.0 6.0 tensor([1.4682, 0.4894, 0.1631])
progress: 87 0.006653300020843744
	grad: 1.0 2.0 tensor([0.3131, 0.3131, 0.3131])
	grad: 2.0 4.0 tensor([-1.7577, -0.8789, -0.4394])
	grad: 3.0 6.0 tensor([1.4646, 0.4882, 0.1627])
progress: 88 0.0066203586757183075
	grad: 1.0 2.0 tensor([0.3135, 0.3135, 0.3135])
	grad: 2.0 4.0 tensor([-1.7548, -0.8774, -0.4387])
	grad: 3.0 6.0 tensor([1.4610, 0.4870, 0.1623])
progress: 89 0.0065881176851689816
	grad: 1.0 2.0 tensor([0.3139, 0.3139, 0.3139])
	grad: 2.0 4.0 tensor([-1.7520, -0.8760, -0.4380])
	grad: 3.0 6.0 tensor([1.4576, 0.4859, 0.1620])
progress: 90 0.0065572685562074184
	grad: 1.0 2.0 tensor([0.3143, 0.3143, 0.3143])
	grad: 2.0 4.0 tensor([-1.7493, -0.8747, -0.4373])
	grad: 3.0 6.0 tensor([1.4542, 0.4847, 0.1616])
progress: 91 0.0065271081402897835
	grad: 1.0 2.0 tensor([0.3147, 0.3147, 0.3147])
	grad: 2.0 4.0 tensor([-1.7466, -0.8733, -0.4367])
	grad: 3.0 6.0 tensor([1.4510, 0.4837, 0.1612])
progress: 92 0.00649801641702652
	grad: 1.0 2.0 tensor([0.3150, 0.3150, 0.3150])
	grad: 2.0 4.0 tensor([-1.7441, -0.8720, -0.4360])
	grad: 3.0 6.0 tensor([1.4478, 0.4826, 0.1609])
progress: 93 0.0064699104987084866
	grad: 1.0 2.0 tensor([0.3153, 0.3153, 0.3153])
	grad: 2.0 4.0 tensor([-1.7415, -0.8708, -0.4354])
	grad: 3.0 6.0 tensor([1.4448, 0.4816, 0.1605])
progress: 94 0.006442630663514137
	grad: 1.0 2.0 tensor([0.3156, 0.3156, 0.3156])
	grad: 2.0 4.0 tensor([-1.7391, -0.8695, -0.4348])
	grad: 3.0 6.0 tensor([1.4418, 0.4806, 0.1602])
progress: 95 0.006416172254830599
	grad: 1.0 2.0 tensor([0.3159, 0.3159, 0.3159])
	grad: 2.0 4.0 tensor([-1.7366, -0.8683, -0.4342])
	grad: 3.0 6.0 tensor([1.4389, 0.4796, 0.1599])
progress: 96 0.006390606984496117
	grad: 1.0 2.0 tensor([0.3161, 0.3161, 0.3161])
	grad: 2.0 4.0 tensor([-1.7343, -0.8671, -0.4336])
	grad: 3.0 6.0 tensor([1.4361, 0.4787, 0.1596])
progress: 97 0.0063657015562057495
	grad: 1.0 2.0 tensor([0.3164, 0.3164, 0.3164])
	grad: 2.0 4.0 tensor([-1.7320, -0.8660, -0.4330])
	grad: 3.0 6.0 tensor([1.4334, 0.4778, 0.1593])
progress: 98 0.0063416799530386925
	grad: 1.0 2.0 tensor([0.3166, 0.3166, 0.3166])
	grad: 2.0 4.0 tensor([-1.7297, -0.8649, -0.4324])
	grad: 3.0 6.0 tensor([1.4308, 0.4769, 0.1590])
progress: 99 0.00631808303296566
predict (after training) 4 8.544171333312988
```
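As a sanity check on what `l.backward()` computes, the sketch below replays the first epoch of the training loop with a fresh weight tensor. It compares each autograd gradient with the hand-derived formula 2(ŷ − y)·(x², x, 1). The helper names `manual_grad` and `first_epoch_grads` are hypothetical and exist only for this check.

```python
import torch

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

def manual_grad(w, x, y):
    """Hand-derived gradient of (w0*x^2 + w1*x + w2 - y)^2 w.r.t. w (hypothetical helper)."""
    y_pred = w[0] * x ** 2 + w[1] * x + w[2]
    return 2 * (y_pred - y) * torch.tensor([x ** 2, x, 1.0])

def first_epoch_grads(lr=0.01):
    """Replay one epoch of per-sample SGD; return (autograd, manual) gradient pairs."""
    w = torch.tensor([1.0, 1.0, 1.0], requires_grad=True)
    pairs = []
    for x, y in zip(x_data, y_data):
        l = (w[0] * x ** 2 + w[1] * x + w[2] - y) ** 2
        l.backward()                                # autograd gradient lands in w.grad
        expected = manual_grad(w.detach(), x, y)    # same weights, formula by hand
        pairs.append((w.grad.clone(), expected))
        with torch.no_grad():
            w -= lr * w.grad                        # same update as the article's loop
        w.grad.zero_()
    return pairs

for auto, manual in first_epoch_grads():
    print(auto, manual, torch.allclose(auto, manual))
```

The three pairs agree. The autograd values also reproduce the first three `grad` lines of the output: [2, 2, 2], [22.88, 11.44, 5.72] and [77.0472, 25.6824, 8.5608].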
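The loss trends downward as the number of epochs grows, as the loss curve drawn by the plotting code shows. The decrease is not strictly monotonic at first: the printed loss rises from about 0.047 at epoch 3 to about 0.145 at epoch 4 before settling.
After 100 epochs the prediction at x = 4 is about 8.54, which is still some distance from the true value 8 (the data lie on y = 2x). The gap can be narrowed by training for more epochs or by adjusting the learning rate, the initial parameters, and so on.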
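You can see where the iterations are heading by solving for the exact fit directly. There are three data points and three weights, and the matrix whose rows are (x², x, 1) is invertible. So `torch.linalg.solve` gives the unique weights (0, 2, 0), i.e. ŷ = 2x, which predicts exactly 8 at x = 4. The sketch also reruns the article's per-sample loop for a chosen number of epochs. `exact_weights`, `train` and `predict` are hypothetical helpers, and `train` mirrors the loop above with the same initial weights and learning rate 0.01.

```python
import torch

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

def exact_weights():
    """Solve A w = y, where each row of A is (x^2, x, 1) (hypothetical helper)."""
    A = torch.tensor([[x ** 2, x, 1.0] for x in x_data])
    y = torch.tensor(y_data)
    return torch.linalg.solve(A, y)

def train(epochs, lr=0.01):
    """Hypothetical helper: the article's per-sample gradient-descent loop without printing."""
    w = torch.tensor([1.0, 1.0, 1.0], requires_grad=True)
    for _ in range(epochs):
        for x, y in zip(x_data, y_data):
            l = (w[0] * x ** 2 + w[1] * x + w[2] - y) ** 2
            l.backward()
            with torch.no_grad():
                w -= lr * w.grad
            w.grad.zero_()
    return w.detach()

def predict(w, x):
    """Evaluate the quadratic model at x."""
    return (w[0] * x ** 2 + w[1] * x + w[2]).item()

w_star = exact_weights()
print('exact weights:', w_star, 'prediction at x=4:', predict(w_star, 4.0))
for epochs in (100, 1000):
    w = train(epochs)
    print(epochs, 'epochs: distance to exact weights',
          torch.linalg.norm(w - w_star).item(), 'prediction at x=4', predict(w, 4.0))
```

With learning rate 0.01 each per-sample step cannot move the weights away from (0, 2, 0): the step factor 0.02·(x⁴ + x² + 1) stays below 2 for every x in the dataset, at most 1.82 at x = 3. So the 1000-epoch distance should come out smaller than the 100-epoch one. How much smaller depends on how slowly the hardest direction converges.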