-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathargs.py
92 lines (84 loc) · 3.19 KB
/
args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import argparse
import torch
from utils import get_device
def _str2bool(value):
    """Convert a command-line boolean spelling (e.g. "true", "no") to a bool.

    ``type=bool`` cannot be used for this in argparse: ``bool("False")`` is
    ``True``, so every non-empty value would be parsed as ``True``.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling;
            argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y"):
        return True
    if normalized in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _parse_gpu_ids(spec):
    """Parse GPU ids separated by commas and/or whitespace into a list of ints.

    Accepts "2", "0,1", "0 1" or "0, 1"; returns [] when no id is present.
    """
    return [int(token) for token in spec.replace(",", " ").split()]


def _build_parser():
    """Create the argument parser holding every experiment option."""
    parser = argparse.ArgumentParser(description='configuration')
    parser.add_argument("--name",
                        default="1",
                        type=str,
                        help="Experiment name, for logging and saving models")
    parser.add_argument("--do_train",
                        action='store_true',
                        default=False,
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        default=False,
                        help="Whether to run eval on the test set.")
    parser.add_argument("--data_dir",
                        default="./data/jave",
                        type=str,
                        help="The input data dir.")
    parser.add_argument("--word_vocab",
                        default="word_vocab.json",
                        type=str,
                        help="The vocabulary file.")
    parser.add_argument("--ontology_vocab",
                        default="attribute_vocab.json",
                        type=str,
                        help="The ontology class file.")
    parser.add_argument("--tokenizer",
                        default="char",
                        type=str,
                        help="The tokenizer type.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="The random seed for initialization")
    # Comma- or space-separated GPU ids; an empty string selects the CPU.
    parser.add_argument('--gpu_ids',
                        type=str,
                        default='2',
                        help="The GPU ids")
    # Hyperparameters
    parser.add_argument("--batch_size",
                        default=512,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--lr",
                        default=2e-4,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--epoch",
                        default=40,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--emb_dim",
                        default=200,
                        type=int,
                        help="The dimension of the embedding")
    parser.add_argument("--encode_dim",
                        default=200,
                        type=int,
                        help="The dimension of the encoding")
    # Takes an explicit value (e.g. "--skip_subject False"); see _str2bool.
    parser.add_argument("--skip_subject",
                        default=True,
                        type=_str2bool,
                        help="Whether to skip the subject")
    return parser


def get_args(argv=None):
    """Parse the experiment options and attach the compute device.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and notebooks).

    Returns:
        argparse.Namespace with every option defined in ``_build_parser``.
        ``gpu_ids`` is converted to a list of ints (empty when running on
        CPU), and ``device`` and ``n_gpu`` are added.
    """
    args = _build_parser().parse_args(argv)
    gpu_ids = _parse_gpu_ids(args.gpu_ids)
    args.gpu_ids = gpu_ids
    if not gpu_ids:
        # No GPU id given: run on the CPU.
        n_gpu = 0
        device = torch.device('cpu')
    else:
        # utils.get_device is project-local; it presumably returns the device
        # for the first id plus a GPU count -- confirm against utils.
        device, n_gpu = get_device(gpu_ids[0])
        if n_gpu > 1:
            # When more than one GPU is reported, use the number of ids the
            # user requested instead.
            n_gpu = len(gpu_ids)
    args.device = device
    args.n_gpu = n_gpu
    return args