forked from huawei-noah/HEBO
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
88 lines (83 loc) · 3.78 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""Command-line entry point that launches one (or a batch of) experiments.

The experiment is chosen by the integer ``experiment`` attribute of the
arguments parsed by ``GeneralArgumentParser``. ``smoketest`` from the same
arguments is forwarded to every runner. Id ``-1`` runs a fixed suite of
experiments back to back.
"""
import importlib
import os

# Hide all GPUs so the experiments run on CPU. This is set before the project
# modules are imported so it is already in place if any of them initialises a
# GPU framework at import time.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

from common.argument_parser import GeneralArgumentParser  # noqa: E402
from exps.run_algos import run_saute, run_simmer  # noqa: E402

# Experiment id -> (config module that exports a ``cfg`` dict, runner that
# receives ``**cfg``). Config modules are imported lazily, only when selected.
_CFG_EXPERIMENTS = {
    # Single pendulum
    10: ("exps.single_pendulum.sac_cfg", run_saute),
    11: ("exps.single_pendulum.ppo_cfg", run_saute),
    12: ("exps.single_pendulum.trpo_cfg", run_saute),
    13: ("exps.single_pendulum.ablation_cfg", run_saute),
    14: ("exps.single_pendulum.pi_simmer_cfg", run_simmer),
    15: ("exps.single_pendulum.q_simmer_cfg", run_simmer),
    16: ("exps.single_pendulum.key_observation_cfg", run_saute),
    # Double pendulum
    20: ("exps.double_pendulum.trpo_cfg", run_saute),
    21: ("exps.double_pendulum.trpo_lag_cfg", run_saute),
    24: ("exps.double_pendulum.ablation_unsafe_cfg", run_saute),
    25: ("exps.double_pendulum.ablation_components_cfg", run_saute),
    # Reacher
    30: ("exps.reacher.reacher_cfg", run_saute),
    # Safety Gym
    40: ("exps.safety_gym.sg_point_cfg", run_saute),
}

# Experiment id -> module that exports its own ``run`` entry point.
_RUN_EXPERIMENTS = {
    22: "exps.double_pendulum.naive_generalization",
    23: "exps.double_pendulum.zero_shot_generalization",
}

# Experiments executed, in this order, by the id -1 bug-hunting sweep.
_SMOKETEST_SUITE = (10, 11, 12, 14, 15, 20, 21, 30, 40)


def _run_cfg_experiment(experiment_id, smoketest):
    """Import ``cfg`` for ``experiment_id`` and pass it to its runner."""
    module_name, runner = _CFG_EXPERIMENTS[experiment_id]
    cfg = importlib.import_module(module_name).cfg
    runner(**cfg, smoketest=smoketest)


def main():
    """Parse the command line and dispatch to the selected experiment.

    Raises:
        NotImplementedError: if the experiment id is not recognised.
    """
    args = GeneralArgumentParser().parse_args()
    current_experiment = args.experiment

    if current_experiment in _CFG_EXPERIMENTS:
        _run_cfg_experiment(current_experiment, args.smoketest)
    elif current_experiment in _RUN_EXPERIMENTS:
        run = importlib.import_module(_RUN_EXPERIMENTS[current_experiment]).run
        # The original code called ``run(**cfg, ...)`` here, but no ``cfg``
        # was ever bound in this branch (guaranteed NameError).
        # NOTE(review): assumes these ``run`` functions need only
        # ``smoketest``; confirm against their signatures.
        run(smoketest=args.smoketest)
    elif current_experiment == -1:
        # Testing for minor bugs (experimental feature): run a batch of
        # experiments with smoketest forced to -1. The meaning of -1 is
        # defined by run_saute/run_simmer (TODO confirm).
        args.smoketest = -1
        for experiment_id in _SMOKETEST_SUITE:
            _run_cfg_experiment(experiment_id, args.smoketest)
    else:
        raise NotImplementedError(
            f"Unknown experiment id: {current_experiment!r}"
        )


if __name__ == "__main__":
    main()