-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmain_train_eval.py
60 lines (44 loc) · 1.68 KB
/
main_train_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import argparse
import pprint
import numpy as np
# Print NumPy floats in fixed-point notation rather than scientific notation,
# which keeps small values readable in console/log output.
np.set_printoptions(suppress=True)
from config import config
from config import update_config, create_logger
from trainer import *
def parse_args():
    """Parse the command line and hand the result to ``update_config``.

    Recognized arguments:
        --cfg  Required string; the experiment configuration file name.
        opts   Positional catch-all; every remaining command-line token is
               gathered into a list.

    Before returning, the parsed namespace is passed to
    ``update_config(config, ...)`` together with the imported ``config``
    object.

    Returns:
        argparse.Namespace: The parsed arguments (``cfg`` and ``opts``).
    """
    cli = argparse.ArgumentParser(
        description='Train Deep Compression System: CVQN')

    cli.add_argument(
        '--cfg',
        type=str,
        required=True,
        help='experiment configure file name',
    )
    # REMAINDER collects everything left on the command line, so free-form
    # overrides can follow the named options.
    cli.add_argument(
        'opts',
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )

    parsed = cli.parse_args()
    update_config(config, parsed)
    return parsed
def main():
    """Run the train/eval schedule selected by the configuration.

    Steps:
        1. Parse the command line (``parse_args`` also calls
           ``update_config``) and create the logger and output directories
           via ``create_logger``.
        2. Instantiate the trainer matching ``config['IMP_TYPE']``
           ('predefine', 'RE' or 'SE').
        3. For each of ``config['TRAIN']['NUM_EPOCH']`` epochs: train, eval,
           call ``re_based_get_imp()`` every 10th epoch when the type is
           'RE', save 'final.pth' on the last epoch, then call
           ``update_lr()``.
        4. Close ``trainer.writer``. This now happens in a ``finally`` clause
           so the writer is released even when an epoch raises or the run
           is interrupted.

    Raises:
        NotImplementedError: If ``config['IMP_TYPE']`` is not one of the
            three supported values.
    """
    args = parse_args()
    logger, model_dir, tb_log_dir = create_logger(config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(config)

    imp_type = config['IMP_TYPE']
    if imp_type == 'predefine':
        trainer = PredefineTrainer(config, logger, model_dir, tb_log_dir)
    elif imp_type == 'RE':
        trainer = RETrainer(config, logger, model_dir, tb_log_dir)
    elif imp_type == 'SE':
        trainer = SETrainer(config, logger, model_dir, tb_log_dir)
    else:
        raise NotImplementedError("Trainer type error.")

    num_epochs = config['TRAIN']['NUM_EPOCH']
    last_epoch = num_epochs - 1
    try:
        for epoch in range(num_epochs):
            trainer.train()
            trainer.eval()

            # Only the 'RE' trainer is asked to run re_based_get_imp(), and
            # only once every 10 epochs.
            if imp_type == 'RE' and (epoch + 1) % 10 == 0:
                trainer.re_based_get_imp()

            # The final checkpoint is written before the last update_lr().
            if epoch == last_epoch:
                trainer.save_checkpoint('final.pth')

            trainer.update_lr()
    finally:
        # NOTE(review): writer is presumably a TensorBoard-style summary
        # writer (it is built from tb_log_dir) -- confirm. Closing it here
        # guarantees release on failure or Ctrl-C as well as on success.
        trainer.writer.close()
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if "__main__" == __name__:
    main()