beam_search.pysnip
https://github.com/roeeaharoni/morphological-reinflection/blob/master/src/task2_ms2s.py
# beam search

# initialize the decoder rnn
s_0 = decoder_rnn.initial_state()

# set prev_output_vec for the first lstm step as BEGIN_WORD
prev_output_vec = char_lookup[alphabet_index[BEGIN_WORD]]
i = 0
beam_width = BEAM_WIDTH
beam = {}
beam[-1] = [([BEGIN_WORD], 1.0, s_0)]  # (sequence, probability, decoder state)
final_states = []

# run the decoder through the sequence and predict characters
while i < MAX_PREDICTION_LEN and len(beam[i - 1]) > 0:

    # at each step, create all expansions from the previous beam
    new_hypos = []
    for hypothesis in beam[i - 1]:
        seq, hyp_prob, prefix_decoder = hypothesis
        last_hypo_char = seq[-1]

        # can't expand finished sequences
        if last_hypo_char == END_WORD:
            continue

        # expand from the last character of the hypothesis
        try:
            prev_output_vec = char_lookup[alphabet_index[last_hypo_char]]
        except KeyError:
            # not a character in the alphabet - impossible to expand
            continue

        # while reading the lemma, feed the aligned input character;
        # once the lemma is finished, pad with epsilon chars
        if i < len(source_word):
            blstm_output = blstm_outputs[i]
            try:
                source_word_input_char_vec = char_lookup[alphabet_index[source_word[i]]]
            except KeyError:
                # handle unseen characters
                source_word_input_char_vec = char_lookup[alphabet_index[UNK]]
        else:
            source_word_input_char_vec = char_lookup[alphabet_index[EPSILON]]
            blstm_output = blstm_outputs[source_word_char_vecs_len - 1]

        # prepare the input vector and perform an lstm step
        decoder_input = concatenate([blstm_output,
                                     prev_output_vec,
                                     source_word_input_char_vec,
                                     char_lookup[alphabet_index[str(i)]],
                                     feats_input])
        s = prefix_decoder.add_input(decoder_input)

        # compute softmax probabilities over the alphabet
        decoder_rnn_output = s.output()
        probs = softmax(R * decoder_rnn_output + bias)
        probs = probs.vec_value()

        # expand - create new hypotheses, one per possible output character
        for index, p in enumerate(probs):
            new_seq = list(seq)
            new_seq.append(inverse_alphabet_index[index])
            new_prob = hyp_prob * p
            if new_seq[-1] == END_WORD:
                # a complete sequence was found - add it to the final states
                final_states.append((new_seq[1:-1], new_prob))
            else:
                new_hypos.append((new_seq, new_prob, s))

    # keep the beam_width most probable expansions, together with their
    # scores and prefix rnn states
    new_probs = [p for (s, p, r) in new_hypos]
    argmax_indices = common.argmax(new_probs, n=beam_width)
    beam[i] = [new_hypos[l] for l in argmax_indices]
    i += 1
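
# The sketch below is a minimal, self-contained illustration of the same
# beam-search bookkeeping used above: hypotheses are (sequence, probability,
# state) tuples, the beam is a dict keyed by time step, finished sequences go
# to final_states, and only the beam_width most probable expansions survive
# each step. It is not the repository's code: toy_step_probs is a hypothetical
# stand-in for the BLSTM/LSTM decoder softmax, and top_n_indices only shows
# what common.argmax(values, n=k) is assumed to return (indices of the k
# largest values).

BEGIN, END = "<w>", "</w>"
ALPHABET = [END, "a", "b"]
BEAM_WIDTH = 2
MAX_LEN = 5


def top_n_indices(values, n):
    # indices of the n largest values (assumed behaviour of common.argmax)
    return sorted(range(len(values)), key=lambda j: values[j], reverse=True)[:n]


def toy_step_probs(prev_char, state):
    # hypothetical scoring function: after 'a', prefer ending the word
    if prev_char == "a":
        return [0.5, 0.3, 0.2], state + 1
    return [0.1, 0.7, 0.2], state + 1


beam = {-1: [([BEGIN], 1.0, 0)]}
final_states = []
i = 0
while i < MAX_LEN and len(beam[i - 1]) > 0:
    new_hypos = []
    for seq, hyp_prob, state in beam[i - 1]:
        if seq[-1] == END:
            continue  # finished sequences are not expanded
        probs, new_state = toy_step_probs(seq[-1], state)
        # expand the hypothesis with every possible next character
        for index, p in enumerate(probs):
            new_seq = seq + [ALPHABET[index]]
            new_prob = hyp_prob * p
            if new_seq[-1] == END:
                final_states.append((new_seq[1:-1], new_prob))
            else:
                new_hypos.append((new_seq, new_prob, new_state))
    # prune to the BEAM_WIDTH most probable expansions
    probs_only = [p for (_, p, _) in new_hypos]
    beam[i] = [new_hypos[j] for j in top_n_indices(probs_only, BEAM_WIDTH)]
    i += 1

best_seq, best_prob = max(final_states, key=lambda fs: fs[1])
print("best:", "".join(best_seq), best_prob)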