深入學習PyTorch中LSTM的輸入和輸出
LSTM參數
官方文檔給出的解釋為:
總共有七個參數,其中隻有前三個是必須的。由於大傢普遍使用PyTorch的DataLoader來形成批量數據,因此batch_first也比較重要。LSTM的兩個常見的應用場景為文本處理和時序預測,因此下面對每個參數我都會從這兩個方面來進行具體解釋。
- input_size:在文本處理中,由於一個單詞沒法參與運算,因此我們得通過Word2Vec來對單詞進行嵌入表示,將每一個單詞表示成一個向量,此時input_size=embedding_size。比如每個句子中有五個單詞,每個單詞用一個100維向量來表示,那麼這裡input_size=100;在時間序列預測中,比如需要預測負荷,每一個負荷都是一個單獨的值,都可以直接參與運算,因此並不需要將每一個負荷表示成一個向量,此時input_size=1。 但如果我們使用多變量進行預測,比如我們利用前24小時每一時刻的[負荷、風速、溫度、壓強、濕度、天氣、節假日信息]來預測下一時刻的負荷,那麼此時input_size=7。
- hidden_size:隱藏層節點個數。可以隨意設置。
- num_layers:層數。nn.LSTMCell與nn.LSTM相比,num_layers默認為1。
- batch_first:默認為False,意義見後文。
Inputs
關於LSTM的輸入,官方文檔給出的定義為:
可以看到,輸入由兩部分組成:input、(初始的隱狀態h_0,初始的單元狀態c_0)
其中input:
input(seq_len, batch_size, input_size)
- seq_len:在文本處理中,如果一個句子有7個單詞,則seq_len=7;在時間序列預測中,假設我們用前24個小時的負荷來預測下一時刻負荷,則seq_len=24。
- batch_size:一次性輸入LSTM中的樣本個數。在文本處理中,可以一次性輸入很多個句子;在時間序列預測中,也可以一次性輸入很多條數據。
- input_size
(h_0, c_0):
h_0(num_directions * num_layers, batch_size, hidden_size) c_0(num_directions * num_layers, batch_size, hidden_size)
h_0和c_0的shape一致。
- num_directions:如果是雙向LSTM,則num_directions=2;否則num_directions=1。num_layers:
- batch_size:
- hidden_size:
Outputs
關於LSTM的輸出,官方文檔給出的定義為:
可以看到,輸出也由兩部分組成:otput、(隱狀態h_n,單元狀態c_n)
其中output的shape為:
output(seq_len, batch_size, num_directions * hidden_size)
h_n和c_n的shape保持不變,參數解釋見前文。
batch_first
如果在初始化LSTM時令batch_first=True,那麼input和output的shape將由:
input(seq_len, batch_size, input_size) output(seq_len, batch_size, num_directions * hidden_size)
變為:
input(batch_size, seq_len, input_size) output(batch_size, seq_len, num_directions * hidden_size)
即batch_size提前。
案例
簡單搭建一個LSTM如下所示:
class LSTM(nn.Module): def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size): super().__init__() self.input_size = input_size self.hidden_size = hidden_size self.num_layers = num_layers self.output_size = output_size self.num_directions = 1 # 單向LSTM self.batch_size = batch_size self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True) self.linear = nn.Linear(self.hidden_size, self.output_size) def forward(self, input_seq): batch_size, seq_len = input_seq[0], input_seq[1] h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device) c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device) # output(batch_size, seq_len, num_directions * hidden_size) output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64) pred = self.linear(output) # (5, 30, 1) pred = pred[:, -1, :] # (5, 1) return pred
其中定義模型的代碼為:
self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True) self.linear = nn.Linear(self.hidden_size, self.output_size)
我們加上具體的數字:
self.lstm = nn.LSTM(self.input_size=1, self.hidden_size=64, self.num_layers=5, batch_first=True) self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)
再看前向傳播:
def forward(self, input_seq): batch_size, seq_len = input_seq[0], input_seq[1] h_0 = torch.randn(self.num_directions * self.num_layers, batch_size, self.hidden_size).to(device) c_0 = torch.randn(self.num_directions * self.num_layers, batch_size, self.hidden_size).to(device) # input(batch_size, seq_len, input_size) # output(batch_size, seq_len, num_directions * hidden_size) output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64) pred = self.linear(output) # (5, 30, 1) pred = pred[:, -1, :] # (5, 1) return pred
假設用前30個預測下一個,則seq_len=30,batch_size=5,由於設置瞭batch_first=True,因此,輸入到LSTM中的input的shape應該為:
input(batch_size, seq_len, input_size) = input(5, 30, 1)
經過DataLoader處理後的input_seq為:
input_seq(batch_size, seq_len, input_size) = input_seq(5, 30, 1)
然後將input_seq送入LSTM:
output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)
根據前文,output的shape為:
output(batch_size, seq_len, num_directions * hidden_size) = output(5, 30, 64)
全連接層的定義為:
self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)
然後將output送入全連接層:
pred = self.linear(output) # pred(5, 30, 1)
得到的預測值shape為(5, 30, 1),由於輸出是輸入右移,我們隻需要取pred第二維度(time)中的最後一個數據:
pred = pred[:, -1, :] # (5, 1)
這樣,我們就得到瞭預測值,然後與label求loss,然後再反向更新參數即可。
到此這篇關於深入學習PyTorch中LSTM的輸入和輸出的文章就介紹到這瞭,更多相關PyTorch LSTM內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!
推薦閱讀:
- pytorch lstm gru rnn 得到每個state輸出的操作
- pytorch中使用LSTM詳解
- Pytorch實現LSTM案例總結學習
- Python中LSTM回歸神經網絡時間序列預測詳情
- python中的Pytorch建模流程匯總