forked from OpenNMT/OpenNMT-py
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtranslate.py
executable file
·138 lines (112 loc) · 4.6 KB
/
translate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#!/usr/bin/env python
from __future__ import division
from builtins import bytes
import os
import argparse
import math
import codecs
import torch
import onmt
import onmt.IO
import opts
from itertools import takewhile, count
from tools.visualize_att import plot_attention
try:
from itertools import zip_longest
except ImportError:
from itertools import izip_longest as zip_longest
# Build the command-line interface. The option definitions themselves
# live in the project's ``opts`` module; this script only wires them up.
parser = argparse.ArgumentParser(
    description='translate.py',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
opts.add_md_help_argument(parser)
opts.translate_opts(parser)

# Parsed once at module level; main() reads (and extends) this namespace.
opt = parser.parse_args()

# Any -batch_size other than 1 is overridden, with a warning to the user.
if opt.batch_size != 1:
    print("WARNING: -batch_size isn't supported currently, "
          "we set it to 1 for now!")
    opt.batch_size = 1
def report_score(name, score_total, words_total):
    """Print the average per-word score and the perplexity for a corpus.

    Args:
        name: Label placed in front of both figures (e.g. ``'PRED'``).
        score_total: Sum of the scores over the whole corpus.
        words_total: Number of words the scores were accumulated over.

    When ``words_total`` is zero (for instance, an empty input file) there
    is nothing to average, so a short notice is printed instead of raising
    ``ZeroDivisionError``.
    """
    if words_total == 0:
        print("%s No words predicted" % name)
        return

    avg_score = score_total / words_total
    try:
        perplexity = math.exp(-avg_score)
    except OverflowError:
        # exp() overflows for very negative average scores; report an
        # infinite perplexity rather than aborting the final summary.
        perplexity = float("inf")

    print("%s AVG SCORE: %.4f, %s PPL: %.4f" % (
        name, avg_score,
        name, perplexity))
def get_src_words(src_indices, index2str):
    """Return the source words as one space-separated string.

    Each entry of ``src_indices`` is looked up in ``index2str``; the
    conversion stops at the first word equal to ``onmt.IO.PAD_WORD``,
    which is not included in the result.
    """
    tokens = []
    for idx in src_indices:
        token = index2str[idx]
        if token == onmt.IO.PAD_WORD:
            break
        tokens.append(token)
    return " ".join(tokens)
def main():
    """Translate ``opt.src`` and write the hypotheses to ``opt.output``.

    Uses the module-level ``opt`` namespace. For every source sentence the
    top ``opt.n_best`` predictions are written, one per line, to
    ``opt.output``. With ``opt.verbose`` the source, best prediction, gold
    target (when ``opt.tgt`` is given) and n-best list are also echoed to
    stdout. Corpus-level scores are reported at the end, and the beam
    accumulator is dumped as JSON when ``opt.dump_beam`` is set.

    Both output files are opened with ``with`` so they are flushed and
    closed even if translation fails part-way through.
    """
    import json  # local import: only needed for the -dump_beam output

    # The translator also takes the model option defaults; collect them by
    # parsing an empty argument list with the model options registered.
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    opt.cuda = opt.gpu > -1
    if opt.cuda:
        torch.cuda.set_device(opt.gpu)
    translator = onmt.Translator(opt, dummy_opt.__dict__)

    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    with codecs.open(opt.output, 'w', 'utf-8') as out_file:
        if opt.dump_beam != "":
            translator.initBeamAccum()

        data = onmt.IO.ONMTDataset(
            opt.src, opt.tgt, translator.fields,
            use_filter_pred=False)

        test_data = onmt.IO.OrderedIterator(
            dataset=data, device=opt.gpu,
            batch_size=opt.batch_size, train=False, sort=False,
            shuffle=False)

        counter = count(1)
        for batch in test_data:
            pred_batch, gold_batch, pred_scores, gold_scores, attn, src \
                = translator.translate(batch, data)
            pred_score_total += sum(score[0] for score in pred_scores)
            pred_words_total += sum(len(x[0]) for x in pred_batch)
            if opt.tgt:
                gold_score_total += sum(gold_scores)
                gold_words_total += sum(len(x) for x in batch.tgt[1:])

            # z_batch: an iterator over the predictions, their scores,
            # the gold sentence, its score, and the source sentence for
            # each sentence in the batch. It has to be zip_longest instead
            # of plain-old zip because the gold_batch has length 0 if the
            # target is not included.
            z_batch = zip_longest(
                pred_batch, gold_batch,
                pred_scores, gold_scores,
                (sent.squeeze(1) for sent in src.split(1, dim=1)))

            for (pred_sents, gold_sent, pred_score,
                 gold_score, src_sent) in z_batch:
                n_best_preds = [" ".join(pred)
                                for pred in pred_sents[:opt.n_best]]
                out_file.write('\n'.join(n_best_preds))
                out_file.write('\n')
                out_file.flush()

                if not opt.verbose:
                    continue

                sent_number = next(counter)
                words = get_src_words(
                    src_sent, translator.fields["src"].vocab.itos)

                # Raw bytes are written to file descriptor 1 (stdout) so
                # the UTF-8 encoding is applied explicitly here.
                os.write(1, bytes('\nSENT %d: %s\n' %
                                  (sent_number, words), 'UTF-8'))

                best_pred = n_best_preds[0]
                best_score = pred_score[0]
                os.write(1, bytes('PRED %d: %s\n' %
                                  (sent_number, best_pred), 'UTF-8'))
                print("PRED SCORE: %.4f" % best_score)

                if opt.tgt:
                    tgt_sent = ' '.join(gold_sent)
                    os.write(1, bytes('GOLD %d: %s\n' %
                                      (sent_number, tgt_sent), 'UTF-8'))
                    print("GOLD SCORE: %.4f" % gold_score)

                if len(n_best_preds) > 1:
                    print('\nBEST HYP:')
                    for score, sent in zip(pred_score, n_best_preds):
                        os.write(1, bytes("[%.4f] %s\n" % (score, sent),
                                          'UTF-8'))

                # NOTE(review): kept on the verbose path because it reads
                # `words` and `sent_number`, which are only assigned
                # there -- confirm this is the intended nesting.
                if opt.attn_debug:
                    plot_attention(attn, words, pred_sents[0],
                                   filename=sent_number)

    report_score('PRED', pred_score_total, pred_words_total)
    if opt.tgt:
        report_score('GOLD', gold_score_total, gold_words_total)

    if opt.dump_beam:
        with codecs.open(opt.dump_beam, 'w', 'utf-8') as beam_file:
            json.dump(translator.beam_accum, beam_file)
# Run the translation only when executed as a script, not on import.
if __name__ == "__main__":
    main()