-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_dqn.py
37 lines (27 loc) · 1.02 KB
/
train_dqn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# -------------------------------------------------------------------#
# Released under the MIT license (https://opensource.org/licenses/MIT)
# Contact: [email protected]
# Enhancement Copyright 2016, Mrinal Haloi
# -------------------------------------------------------------------#
import random
import os
import tensorflow as tf
from core.solver import Solver
from env.environment import GymEnvironment, SimpleGymEnvironment
from config.config import cfg
# Fix the random seeds at import time so runs are repeatable: TensorFlow's
# graph-level seed and Python's built-in `random` module are seeded separately.
tf.set_random_seed(123)
random.seed(12345)
def main(_):
    """Build the configured Gym environment and run DQN training.

    Intended as the entry point handed to ``tf.app.run()``; the single
    positional argument it supplies is ignored.

    Raises:
        OSError: if the model directory cannot be created and does not
            already exist as a directory.
    """
    # Directory handed to the Solver (presumably for model output -- confirm
    # against core.solver). Kept in one place instead of repeating the literal.
    model_dir = '/tmp/model_dir'

    # Limit this process to a 0.4 fraction of GPU memory.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.4)
    session_config = tf.ConfigProto(gpu_options=gpu_options)

    with tf.Session(config=session_config) as sess:
        if cfg.env_type == 'simple':
            env = SimpleGymEnvironment(cfg)
        else:
            env = GymEnvironment(cfg)

        # Create-then-handle instead of exists()-then-mkdir(): avoids the race
        # where another process creates the directory between the check and
        # the call, and also creates missing parents. Written without
        # ``exist_ok`` so it works on Python 2 as well as Python 3.
        try:
            os.makedirs(model_dir)
        except OSError:
            # An already-existing directory is fine; anything else (permission
            # error, a regular file in the way) is a real failure.
            if not os.path.isdir(model_dir):
                raise

        solver = Solver(cfg, env, sess, model_dir)
        solver.train()
# Script entry point: only start training when executed directly, not on import.
if __name__ == '__main__':
    # tf.app.run() handles flag parsing and then invokes this module's main().
    tf.app.run()