Beam search and how to implement it in PyTorch
This post records two different versions of beam search.
Version 1
The search works like a level-order (breadth-first) traversal maintained with a queue. Each iteration expands every node of the current level; each of these nodes contributes its top-k successors as candidates for the next level, and the top k among all candidates are pushed into the queue as the next level.
This is BFS with a width constraint, a form of heuristic search, and a greedy algorithm. If k -> inf, it becomes equivalent to plain BFS.
Starting from the root node (<sos>), pick the k most probable tokens among all possibilities (typically tens of thousands) and expand them into the nodes of the next level.
Then, among all the nodes these k nodes can expand to (usually k × tens of thousands), again pick the k with the highest probability (note that the probability here is cumulative, i.e. the product of the probabilities along the path from the root to that node) and expand them. The k children chosen at this step may have all k nodes of the previous level as parents, or only some of them, or even come entirely from a single node. Continue in the same way.
The result is a tree with k nodes per level (except for the root, the final level, and levels where fewer than k candidates exist).
In practice the log of the probability is used (so products become sums), to avoid values that are too small.
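To make the per-level step concrete, here is a minimal sketch (not the project code quoted below); next_token_probs is a hypothetical function that returns the next-token distribution for a given prefix:

import math

def expand_one_level(beams, next_token_probs, k):
    # beams: list of (token_sequence, cumulative_log_prob) pairs
    candidates = []
    for seq, score in beams:
        for token, p in next_token_probs(seq).items():
            # accumulate log probabilities instead of multiplying raw probabilities
            candidates.append((seq + [token], score + math.log(p)))
    # keep only the k best candidates across all parents
    candidates.sort(key=lambda x: x[1], reverse=True)
    return candidates[:k]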
A worked example with k = 2:
From <sos>, pick the two most probable tokens: "i": 0.6 and "he": 0.4. All other words are negligible.
There are 4 candidate expansions: (1) after "i", suppose the most probable continuations are "love": 0.7 and "like": 0.3, other words negligible; (2) after "he", suppose the most probable continuations are "hates": 0.9 and "loves": 0.1, other words negligible. Among these 4 possibilities, "i love" has probability 0.6 × 0.7 = 0.42, "i like" has 0.6 × 0.3 = 0.18, "he hates" has 0.4 × 0.9 = 0.36, and "he loves" has 0.4 × 0.1 = 0.04. Keep the two most probable: "i love" and "he hates".
The next level again has 4 candidate expansions: (1) after "i love", suppose the most probable continuation is "you": 0.9, with everything else summing to 0.1; (2) after "he hates", suppose the most probable continuations are "her": 0.8 and "himself": 0.1, with everything else summing to 0.1. Then "i love you" has probability 0.42 × 0.9 = 0.378 and "he hates her" has 0.36 × 0.8 = 0.288; nothing else needs to be computed, since everything else is smaller. Again keep the two most probable: "i love you" and "he hates her".
At the next level, "i love you" should expand two children, and <eos> turns out to have probability 0.99 with all other words summing to 0.01; "he hates her" should also expand two children, and <eos> again has probability 0.99 with all other words summing to 0.01. So the two most probable sequences are "i love you <eos>" and "he hates her <eos>". Since both branches have reached <eos>, both stop searching.
Finally, of the two, pick the more probable one, "i love you <eos>", and stop.
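The first expansion step of this toy example can be replayed in a few lines (a quick check, not part of the project code below):

step1 = {"i": 0.6, "he": 0.4}
step2 = {"i": {"love": 0.7, "like": 0.3}, "he": {"hates": 0.9, "loves": 0.1}}

candidates = [((w1, w2), p1 * p2)
              for w1, p1 in step1.items()
              for w2, p2 in step2[w1].items()]
beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:2]
print(beams)  # the two surviving beams: ('i', 'love') with ~0.42 and ('he', 'hates') with ~0.36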
The code below is excerpted from a project (only the key parts are kept), implemented in PyTorch:
class Node(object):
    def __init__(self, hidden, previous_node, decoder_input, attn, log_prob, length):
        self.hidden = hidden
        self.previous_node = previous_node
        self.decoder_input = decoder_input
        self.attn = attn
        self.log_prob = log_prob
        self.length = length


def beam_search(beam_width):
    ...
    root = Node(hidden, None, decoder_input, None, 0, 1)
    q = Queue()
    q.put(root)

    end_nodes = []  # finished nodes, used later for backtracking
    while not q.empty():
        candidates = []  # nodes that may be expanded at this level; only the top-k children of each parent need to be considered
        for _ in range(q.qsize()):
            node = q.get()
            decoder_input = node.decoder_input
            hidden = node.hidden

            # stopping condition for this branch
            if decoder_input.item() == EOS or node.length >= 50:
                end_nodes.append(node)
                continue

            log_prob, hidden, attn = decoder(decoder_input, hidden, encoder_input)

            log_prob, indices = log_prob.topk(beam_width)  # pick the k most probable children of this parent
            for k in range(beam_width):
                index = indices[k].unsqueeze(0)
                log_p = log_prob[k].item()
                child = Node(hidden, node, index, attn, node.log_prob + log_p, node.length + 1)
                candidates.append((node.log_prob + log_p, child))  # build the candidate child; note the log probability is accumulated

        candidates = sorted(candidates, key=lambda x: x[0], reverse=True)  # sort the candidates
        length = min(len(candidates), beam_width)  # take the top k; if there are fewer than k, take them all
        for i in range(length):
            q.put(candidates[i][1])

    # backtracking follows, omitted
    ...
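The omitted backtracking simply follows the previous_node pointers from each finished node back to the root. A minimal sketch of what it could look like, assuming the Node class above:

def backtrack(end_nodes, topk=1):
    results = []
    # prefer the hypotheses with the highest accumulated log probability
    for node in sorted(end_nodes, key=lambda n: n.log_prob, reverse=True)[:topk]:
        tokens = []
        while node is not None:
            tokens.append(node.decoder_input)  # token tensors along the path
            node = node.previous_node
        results.append(tokens[::-1])  # reverse to get root-to-leaf order
    return results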
Version 2
Instead of a level-order traversal, each step pops the highest-scoring node from a priority queue, expands it, and pushes that node's top-k children back into the priority queue. The loop stops when the required number of complete hypotheses has been found or when the number of nodes placed in the queue exceeds a limit.
import operator
import torch
import torch.nn as nn
import torch.nn.functional as F
from queue import PriorityQueue

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SOS_token = 0
EOS_token = 1
MAX_LENGTH = 50


class DecoderRNN(nn.Module):
    def __init__(self, embedding_size, hidden_size, output_size, cell_type, dropout=0.1):
        '''
        Illustrative decoder
        '''
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.cell_type = cell_type
        self.embedding = nn.Embedding(num_embeddings=output_size,
                                      embedding_dim=embedding_size)
        self.rnn = nn.GRU(embedding_size, hidden_size, bidirectional=True,
                          dropout=dropout, batch_first=False)
        self.dropout_rate = dropout
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, not_used):
        embedded = self.embedding(input).transpose(0, 1)  # [B,1] -> [1, B, D]
        embedded = F.dropout(embedded, self.dropout_rate)
        output = embedded
        # with batch_first=False, output has shape
        # (seq_len, batch_size, num_directions * hidden_size) = [1, batch_size, 2*hidden_size]
        output, hidden = self.rnn(output, hidden)
        out = self.out(output.squeeze(0))
        # output has shape [batch_size, vocab_size]
        # hidden has shape [num_layers * num_directions, batch_size, hidden_size]
        output = F.log_softmax(out, dim=1)
        return output, hidden


class BeamSearchNode(object):
    def __init__(self, hiddenstate, previousNode, wordId, logProb, length):
        '''
        :param hiddenstate:
        :param previousNode:
        :param wordId:
        :param logProb:
        :param length:
        '''
        self.h = hiddenstate
        self.prevNode = previousNode
        self.wordid = wordId
        self.logp = logProb
        self.leng = length

    def eval(self, alpha=1.0):
        reward = 0
        # Add here a function for shaping a reward
        return self.logp / float(self.leng - 1 + 1e-6) + alpha * reward


decoder = DecoderRNN()  # illustrative; the real constructor arguments are omitted here


def beam_decode(target_tensor, decoder_hiddens, encoder_outputs=None):
    '''
    :param target_tensor: target indexes tensor of shape [B, T] where B is the batch size and T is the maximum length of the output sentence
    :param decoder_hiddens: input tensor of shape [1, B, H] for start of the decoding
    :param encoder_outputs: if you are using attention mechanism you can pass encoder outputs, [T, B, H] where T is the maximum length of input sentence
    :return: decoded_batch
    '''
    beam_width = 10
    topk = 1  # how many sentences do you want to generate
    decoded_batch = []

    # decoding goes sentence by sentence
    for idx in range(target_tensor.size(0)):
        if isinstance(decoder_hiddens, tuple):  # LSTM case
            decoder_hidden = (decoder_hiddens[0][:, idx, :].unsqueeze(0),
                              decoder_hiddens[1][:, idx, :].unsqueeze(0))
        else:
            decoder_hidden = decoder_hiddens[:, idx, :].unsqueeze(0)
        encoder_output = encoder_outputs[:, idx, :].unsqueeze(1)

        # Start with the start of the sentence token
        decoder_input = torch.LongTensor([[SOS_token]], device=device)

        # Number of sentences to generate
        endnodes = []
        number_required = min((topk + 1), topk - len(endnodes))

        # starting node - hidden vector, previous node, word id, logp, length
        node = BeamSearchNode(decoder_hidden, None, decoder_input, 0, 1)
        nodes = PriorityQueue()

        # start the queue
        nodes.put((-node.eval(), node))
        qsize = 1

        # start beam search
        while True:
            # give up when decoding takes too long
            if qsize > 2000:
                break

            # fetch the best node
            score, n = nodes.get()
            decoder_input = n.wordid
            decoder_hidden = n.h

            if n.wordid.item() == EOS_token and n.prevNode != None:
                endnodes.append((score, n))
                # if we reached the maximum number of sentences required
                if len(endnodes) >= number_required:
                    break
                else:
                    continue

            # decode for one step using the decoder
            # decoder_output has shape [batch_size, vocab_size]
            # decoder_hidden has shape [num_layers * num_directions, batch_size, hidden_size]
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_output)

            # PUT HERE REAL BEAM SEARCH OF TOP
            # log_prob, indexes have shape [batch_size, beam_width] = [1, beam_width]
            log_prob, indexes = torch.topk(decoder_output, beam_width, dim=1)
            nextnodes = []

            for new_k in range(beam_width):
                # decoded_t: [1, 1]; view(1, -1) turns the scalar tensor into a [1, 1] tensor
                decoded_t = indexes[0][new_k].view(1, -1)
                # item() turns the tensor scalar into a Python number
                log_p = log_prob[0][new_k].item()

                node = BeamSearchNode(decoder_hidden, n, decoded_t, n.logp + log_p, n.leng + 1)
                score = -node.eval()
                nextnodes.append((score, node))

            # put them into the queue
            for i in range(len(nextnodes)):
                score, nn = nextnodes[i]
                nodes.put((score, nn))
            # increase qsize
            qsize += len(nextnodes) - 1

        # choose the nbest paths and back-trace them
        if len(endnodes) == 0:
            endnodes = [nodes.get() for _ in range(topk)]

        utterances = []
        for score, n in sorted(endnodes, key=operator.itemgetter(0)):
            utterance = []
            utterance.append(n.wordid)
            # back trace
            while n.prevNode != None:
                n = n.prevNode
                utterance.append(n.wordid)

            utterance = utterance[::-1]
            utterances.append(utterance)

        decoded_batch.append(utterances)

    return decoded_batch


def greedy_decode(decoder_hidden, encoder_outputs, target_tensor):
    '''
    :param target_tensor: target indexes tensor of shape [B, T] where B is the batch size and T is the maximum length of the output sentence
    :param decoder_hidden: input tensor of shape [1, B, H] for start of the decoding
    :param encoder_outputs: if you are using attention mechanism you can pass encoder outputs, [T, B, H] where T is the maximum length of input sentence
    :return: decoded_batch
    '''
    batch_size, seq_len = target_tensor.size()
    decoded_batch = torch.zeros((batch_size, MAX_LENGTH))
    decoder_input = torch.LongTensor([[SOS_token] for _ in range(batch_size)], device=device)

    for t in range(MAX_LENGTH):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)

        topv, topi = decoder_output.data.topk(1)  # get candidates
        topi = topi.view(-1)
        decoded_batch[:, t] = topi

        decoder_input = topi.detach().view(-1, 1)

    return decoded_batch
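One caveat with this implementation: PriorityQueue compares the (score, node) tuples it stores, so when two scores are exactly equal Python falls back to comparing the BeamSearchNode objects themselves and raises a TypeError. A common workaround is to add a monotonically increasing counter as a tie-breaker; a minimal sketch, assuming the BeamSearchNode class above:

import itertools
from queue import PriorityQueue

counter = itertools.count()  # unique, increasing tie-breaker
nodes = PriorityQueue()

node = BeamSearchNode(None, None, None, 0.0, 1)  # placeholder node, just for illustration
# push: the counter decides ties, so equal scores never compare the nodes themselves
nodes.put((-node.eval(), next(counter), node))
# pop: unpack the extra field accordingly
score, _, best = nodes.get()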
Appendix: a simple beam search example, implemented and explained
Just read the code:
from math import log
from numpy import array
from numpy import argmax

# beam search
def beam_search_decoder(data, k):
    sequences = [[list(), 1.0]]
    # walk over each step in sequence
    for row in data:
        all_candidates = list()
        # expand each current candidate
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score * -log(row[j])]
                all_candidates.append(candidate)
        # order all candidates by score
        ordered = sorted(all_candidates, key=lambda tup: tup[1])
        # select k best
        sequences = ordered[:k]
    return sequences

def greedy_decoder(data):
    # index for largest probability each row
    return [argmax(s) for s in data]

# define a sequence of 10 words over a vocab of 5 words
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)
# decode sequence
result = beam_search_decoder(data, 3)
# print result
for seq in result:
    print(seq)
The value of sequences in the first few loop iterations:
[[[4], 0.6931471805599453], [[3], 0.916290731874155], [[2], 1.2039728043259361]]
[[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[3, 0], 0.6351243373717793]]
[[[4, 0, 4], 0.33302465198892944], [[4, 0, 3], 0.4402346437542523], [[4, 1, 4], 0.4402346437542523]]
The final printed result:
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108]
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397]
[[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]
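Note that this toy implementation scores a candidate as the running product of -log(p) terms and sorts in ascending order (smaller is better), which happens to work for this data but is not the usual formulation. The standard choice is to sum log probabilities and keep the largest; a possible variant of the function above:

from math import log

def beam_search_decoder_logsum(data, k):
    # score = cumulative sum of log probabilities; larger is better
    sequences = [[list(), 0.0]]
    for row in data:
        all_candidates = []
        for seq, score in sequences:
            for j in range(len(row)):
                all_candidates.append([seq + [j], score + log(row[j])])
        # keep the k sequences with the highest cumulative log probability
        sequences = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)[:k]
    return sequences

# On the data above, this variant also ranks [4, 0, 4, 0, 4, 0, 4, 0, 4, 0] first.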
The above is based on my personal experience; I hope it can serve as a reference, and I hope you will continue to support WalkonNet.