-
Notifications
You must be signed in to change notification settings - Fork 63
/
Copy pathrun_pong_no_frameskip_v4.py
155 lines (135 loc) · 4.29 KB
/
run_pong_no_frameskip_v4.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# -*- coding: utf-8 -*-
"""Train or test algorithms on PongNoFrameskip-v4.
- Author: Curt Park
- Contact: [email protected]
"""
import argparse
import datetime
from rl_algorithms import build_agent
from rl_algorithms.common.env.atari_wrappers import atari_env_generator
import rl_algorithms.common.helper_functions as common_utils
from rl_algorithms.utils import YamlConfig
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the PongNoFrameskip-v4 runner.

    Options are read from ``sys.argv``. ``--off-render`` and
    ``--off-framestack`` are negative switches: they store ``False`` into
    ``render`` / ``framestack``, which otherwise default to ``True``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser(description="Pytorch RL algorithms")
    add = cli.add_argument

    # --- reproducibility / configuration source ---------------------------
    add("--seed", type=int, default=161, help="random seed for reproducibility")
    add(
        "--cfg-path",
        type=str,
        default="./configs/pong_no_frameskip_v4/dqn.yaml",
        help="config path",
    )

    # --- run-mode flags (all False unless given) ----------------------------
    add(
        "--integration-test",
        dest="integration_test",
        action="store_true",
        help="for integration test",
    )
    add(
        "--grad-cam",
        dest="grad_cam",
        action="store_true",
        help="test mode with viewing Grad-CAM",
    )
    add("--test", dest="test", action="store_true", help="test mode (no training)")

    # --- checkpoint / rendering / logging -----------------------------------
    add(
        "--load-from",
        type=str,
        default=None,
        help="load the saved model and optimizer at the beginning",
    )
    add("--off-render", dest="render", action="store_false", help="turn off rendering")
    add(
        "--render-after",
        type=int,
        default=0,
        help="start rendering after the input number of episode",
    )
    add("--log", dest="log", action="store_true", help="turn on logging")

    # --- training / evaluation schedule -------------------------------------
    add("--save-period", type=int, default=20, help="save model period")
    add("--episode-num", type=int, default=500, help="total episode num")
    add("--max-episode-steps", type=int, default=None, help="max episode step")
    add("--interim-test-num", type=int, default=5, help="interim test number")

    # --- observation / analysis options -------------------------------------
    add(
        "--off-framestack",
        dest="framestack",
        action="store_false",
        help="turn off framestack",
    )
    add("--saliency-map", action="store_true", help="save saliency map")

    return cli.parse_args()
def env_generator(env_name, max_episode_steps, frame_stack, *, base_seed: int = 777):
    """Return a factory that builds and seeds a fresh Atari environment per rank.

    Args:
        env_name: environment id forwarded to ``atari_env_generator``.
        max_episode_steps: forwarded to ``atari_env_generator`` unchanged
            (``main`` passes ``args.max_episode_steps``, which may be ``None``).
        frame_stack: forwarded to ``atari_env_generator`` as ``frame_stack=``.
        base_seed: offset for per-rank seeds. The env built for ``rank`` is
            seeded with ``base_seed + rank + 1``. Defaults to 777, the value
            that used to be hard-coded, so existing callers get identical seeds.

    Returns:
        A callable ``_thunk(rank)`` that creates, seeds and returns a new env.
    """

    def _thunk(rank: int):
        env = atari_env_generator(env_name, max_episode_steps, frame_stack=frame_stack)
        # Offset by rank so envs built for different ranks get different seeds.
        env.seed(base_seed + rank + 1)
        return env

    return _thunk
def main():
    """Parse CLI options, build the Pong env and agent, then train or test it.

    Dispatch after the agent is built:
      * without ``--test``: ``agent.train()``;
      * ``--test --grad-cam``: ``agent.test_with_gradcam()``;
      * ``--test --saliency-map``: ``agent.test_with_saliency_map()``;
      * ``--test`` alone: ``agent.test()``.
    Flag combinations that would otherwise be silently ignored emit a warning.
    """
    import warnings

    args = parse_args()

    # Surface flag combinations whose effect would otherwise be silently dropped.
    if not args.test and (args.grad_cam or args.saliency_map):
        warnings.warn(
            "--grad-cam / --saliency-map only take effect together with --test; "
            "they are ignored in training mode."
        )
    elif args.test and args.grad_cam and args.saliency_map:
        warnings.warn(
            "Both --grad-cam and --saliency-map were given; only Grad-CAM will run."
        )

    # env initialization
    env_name = "PongNoFrameskip-v4"
    env_gen = env_generator(
        env_name, args.max_episode_steps, frame_stack=args.framestack
    )
    env = env_gen(0)

    # set a random seed
    common_utils.set_random_seed(args.seed, env)

    # Timestamp string placed in the agent's log config.
    curr_time = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
    cfg = YamlConfig(dict(agent=args.cfg_path)).get_config_dict()

    # If running integration test, simplify experiment
    if args.integration_test:
        cfg = common_utils.set_cfg_for_intergration_test(cfg)

    env_info = dict(
        name=env.spec.id,
        observation_space=env.observation_space,
        action_space=env.action_space,
        is_atari=True,
        env_generator=env_gen,
    )
    log_cfg = dict(agent=cfg.agent.type, curr_time=curr_time, cfg_path=args.cfg_path)
    build_args = dict(
        env=env,
        env_info=env_info,
        log_cfg=log_cfg,
        is_test=args.test,
        load_from=args.load_from,
        is_render=args.render,
        render_after=args.render_after,
        is_log=args.log,
        save_period=args.save_period,
        episode_num=args.episode_num,
        # NOTE(review): taken from the env spec rather than args.max_episode_steps;
        # presumably the wrapper records the effective limit there — confirm.
        max_episode_steps=env.spec.max_episode_steps,
        interim_test_num=args.interim_test_num,
    )
    agent = build_agent(cfg.agent, build_args)

    # The first branch handles training, so every later branch already implies --test.
    if not args.test:
        agent.train()
    elif args.grad_cam:
        agent.test_with_gradcam()
    elif args.saliency_map:
        agent.test_with_saliency_map()
    else:
        agent.test()
if __name__ == "__main__":
    # Entry point: run only when executed as a script, not on import.
    main()