-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
120 lines (95 loc) · 3.97 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import os
import random
import logging
import argparse
import importlib
import platform
from pprint import pformat
import numpy as np
import torch
from agents.utils import *
# cuDNN's `enabled` / `benchmark` switches are left at PyTorch's defaults
# here; setup_seed() turns on cuDNN deterministic mode instead.

# Root logging: INFO and above, formatted as "time - level - message".
# Insert " - %(name)s" into the format string to also show logger names.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__file__)

# Use CUDA when it is available, except on Windows where the CPU is always used.
_cuda_ok = torch.cuda.is_available() and platform.system() != 'Windows'
device = torch.device('cuda' if _cuda_ok else 'cpu')
logger.info("Device: {}".format(device))
def setup_seed(seed):
    """Seed every random number generator the script relies on.

    PyTorch (CPU and all CUDA devices), NumPy's global RNG and Python's
    ``random`` module are all seeded with the same value, and cuDNN is
    switched to deterministic mode.

    Args:
        seed: value passed unchanged to each generator's seeding function.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True


# Fixed seed, applied at import time before main() constructs the trainer.
setup_seed(42)
# Shared command-line options. The selected trainer appends its own options
# to this same parser via ``trainer_class.add_cmdline_args(parser)``.
parser = argparse.ArgumentParser()

# agent / task selection
parser.add_argument("--agent", type=str, required=True,
                    help="Agent name")
parser.add_argument("--task", type=str, required=True,
                    help="Task name")

# data
parser.add_argument("--dataset_path", type=str, default="data/catslu/hyps/map/",
                    help="Path or url of the dataset. If empty download according to dataset.")
parser.add_argument("--save_dir", type=str, default="checkpoint/",
                    help="Directory that holds the best-model checkpoint")
parser.add_argument('--save_name', type=str, default="",
                    help="Prefix of the checkpoint file name")

# training
parser.add_argument('--epochs', type=int, required=True,
                    help="Maximum number of training epochs")
parser.add_argument('--early_stop', default=-1, type=int,
                    help="Stop once the trainer's patience reaches this value; "
                         "<= 0 disables early stopping")
parser.add_argument('--mode', type=str, default="train",
                    help="'infer' runs inference; any other value trains")
# NOTE(review): --lr_reduce_patience / --lr_decay are not read in this file;
# presumably consumed by the trainer — confirm.
parser.add_argument('--lr_reduce_patience', default=-1, type=int)
parser.add_argument('--lr_decay', type=float, default=0.5)

# inference
parser.add_argument('--result_path', type=str, default="",
                    help="If set, inference results are written to this path")
parser.add_argument('--infer_data', type=str, default="test",
                    help="Name of the loaded dataset to run inference on")
def get_agent_task(opt):
    """Resolve the trainer class and dataset helpers for an agent/task pair.

    Dynamically imports ``agents.<agent>.trainer`` and ``tasks.<task>``.

    Args:
        opt: mapping with ``'agent'`` and ``'task'`` keys (for example
            ``vars()`` of the parsed command-line arguments).

    Returns:
        Tuple ``(Trainer, get_datasets, build_dataset)`` read from those
        modules.

    Raises:
        ImportError: if either module cannot be imported.
        AttributeError: if a module lacks the expected attribute.
    """
    agent_name, task_name = opt.get('agent'), opt.get('task')

    trainer_cls = getattr(
        importlib.import_module("agents." + agent_name + ".trainer"), "Trainer")

    task_module = importlib.import_module("tasks." + task_name)
    return (
        trainer_cls,
        getattr(task_module, "get_datasets"),
        getattr(task_module, "build_dataset"),
    )
# Two-pass argument parsing. The tolerant first pass (parse_known_args) only
# needs --agent/--task to locate the trainer; the trainer's add_cmdline_args
# then extends the shared parser, and the strict second pass (parse_args)
# parses the full command line.
_first_pass, _unrecognized = parser.parse_known_args()
parsed = vars(_first_pass)
trainer_class, getdata_class, builddata_class = get_agent_task(parsed)
trainer_class.add_cmdline_args(parser)
opt = parser.parse_args()
def _best_checkpoint_path():
    """Return the best-model checkpoint path for the current agent/task.

    The file name is ``<save_name>_<task>_<agent>_best_model``, placed inside
    ``opt.save_dir``. ``os.path.join`` is used so a save_dir given without a
    trailing separator still yields a path inside that directory.
    """
    file_name = "{}_{}_{}_best_model".format(
        opt.save_name, parsed.get('task'), parsed.get('agent'))
    return os.path.join(opt.save_dir, file_name)


def _run_inference(trainer):
    """Load a checkpoint into ``trainer`` and run inference on ``opt.infer_data``.

    The best checkpoint under ``opt.save_dir`` is preferred. Otherwise an
    ``opt.checkpoint`` value already present on ``opt`` is used (presumably
    registered by the trainer's ``add_cmdline_args`` — confirm).

    Raises:
        FileNotFoundError: no best checkpoint exists and ``opt`` has no
            ``checkpoint`` attribute to fall back on.
        ValueError: ``opt.infer_data`` is not one of the loaded datasets.
    """
    if os.path.exists(opt.best_checkpoint_path):
        opt.checkpoint = opt.best_checkpoint_path
    elif not hasattr(opt, "checkpoint"):
        # Fail with a clear message rather than an AttributeError on opt.
        raise FileNotFoundError(
            "best checkpoint {} not found and no checkpoint option is set".format(
                opt.best_checkpoint_path))
    logger.info("load checkpoint from {} ".format(opt.checkpoint))
    trainer.load(opt.checkpoint)
    if opt.infer_data not in trainer.dataset:
        raise ValueError("%s does not exist in datasets" % opt.infer_data)
    result = trainer.infer(opt.infer_data)
    if opt.result_path:
        save_json(result, opt.result_path)


def _run_training(trainer):
    """Train for up to ``opt.epochs`` epochs, then log the test performance.

    Early stopping is active only when ``opt.early_stop > 0``. It is checked
    after each training epoch and again after each validation pass.
    """
    for epoch in range(opt.epochs):
        trainer.train_epoch(epoch)
        if trainer.patience >= opt.early_stop > 0:
            break
        trainer.evaluate(epoch, "valid")
        if trainer.patience >= opt.early_stop > 0:
            break
    logger.info('Test performance {}'.format(trainer.test_performance))


def main():
    """Build the trainer, load every dataset split, then train or infer.

    Uses the module-level ``opt``, ``parsed``, ``trainer_class``,
    ``getdata_class`` and ``builddata_class`` produced by command-line
    parsing. ``opt.mode == "infer"`` runs inference; any other mode runs the
    training loop.
    """
    # makedirs handles nested directories and one that already exists. An
    # empty save_dir means the current directory, which needs no creation.
    if opt.save_dir:
        os.makedirs(opt.save_dir, exist_ok=True)
    opt.best_checkpoint_path = _best_checkpoint_path()
    logger.info("Arguments: %s", pformat(opt))

    trainer = trainer_class(opt, device)
    datasets = getdata_class(opt.dataset_path)
    infer = opt.mode == "infer"
    for name, data in datasets.items():
        trainer.load_data(name, data, builddata_class, infer=infer)
    if opt.mode == "train":
        trainer.set_optim_schedule()
    # NOTE(review): modes other than "train"/"infer" reach the training loop
    # without set_optim_schedule() — confirm that is intended.
    if infer:
        _run_inference(trainer)
    else:
        _run_training(trainer)


if __name__ == '__main__':
    main()