extract_token_reps.py
from roberta.helpers import load_shuffled_model
from fairseq.models.roberta import RobertaModel
from roberta.dataset import get_dataset
import argparse
from collections import defaultdict
import pickle as p


def main():
    parser = argparse.ArgumentParser(description="Generate token embeddings from a corpus.")
    parser.add_argument('-d', "--dataset_name", type=str, default='ptb')
    parser.add_argument('-m', "--model_path", type=str)
    parser.add_argument('-o', "--out_folder", type=str)
    parser.add_argument('-l', "--context_len", type=int, default=100)
    parser.add_argument('-b', "--batch_size", type=int, default=64)
    parser.add_argument('-no', "--no_contexts_limit", type=int, default=100)
    arguments = parser.parse_args()

    # Load dataset splits
    train_iter, val_iter, test_iter = get_dataset(arguments.dataset_name)

    # Load pre-trained model (weights) whose representations we extract
    roberta = load_shuffled_model(arguments.model_path)

    # Default dictionary mapping each token string to a list of context vectors
    embed_dict = defaultdict(list)

    # Iterate over the train set and extract word-aligned features
    for line in train_iter:
        if len(line.strip()) > 0:
            try:
                enc = roberta.extract_features_aligned_to_words(line.strip(), return_all_hiddens=True)
                for tok in enc:
                    # Keep at most no_contexts_limit context vectors per token
                    if len(embed_dict[str(tok)]) < arguments.no_contexts_limit:
                        # print('{:100}{} (...)'.format(str(tok), tok.vector[-1:].cpu().detach().numpy()))
                        embed_dict[str(tok)].append(list(tok.vector.cpu().detach().numpy()))
            except Exception:
                # Skip lines the model cannot encode or align
                continue

    # Write out the embeddings file
    model_name = arguments.model_path.split('/')[-1]
    out_file_name = open(arguments.out_folder + model_name + '-embs-' + arguments.dataset_name +
                         '-cntx_count-' + str(arguments.no_contexts_limit), 'wb')
    p.dump(embed_dict, out_file_name)
    out_file_name.close()


if __name__ == '__main__':
    main()
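
# Example usage (a minimal sketch; the paths and dataset name below are
# placeholders, not part of the original script):
#
#   python extract_token_reps.py \
#       -d ptb \
#       -m /path/to/roberta_checkpoint_dir \
#       -o /path/to/output_dir/ \
#       -no 100
#
# Assuming the run completes, the resulting pickle maps each token string to a
# list of per-context vectors and can be reloaded with something like:
#
#   import pickle
#   with open('/path/to/output_dir/<model_name>-embs-ptb-cntx_count-100', 'rb') as f:
#       embed_dict = pickle.load(f)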