-
Notifications
You must be signed in to change notification settings - Fork 50
/
Copy pathfig4a.py
126 lines (104 loc) · 4.38 KB
/
fig4a.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import yaml
import scipy
import pickle
import argparse
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
def bootstrapped_return(x, y, stride, total_steps, confidence_level=0.95, to_bootstrap=True,
                        show_progress=True):
    """Bin per-run curves into fixed-width step windows and aggregate them across runs.

    For every window ``(i * stride, (i + 1) * stride]`` with ``i = 0 .. total_steps // stride - 1``,
    each run's ``y`` values whose step in ``x`` falls inside the window are averaged. These
    per-run means are then averaged across runs. Their min/max are recorded, and optionally
    a bootstrapped confidence interval of their mean.

    Args:
        x: sequence of per-run 1-D step arrays (or lists). Each is assumed sorted ascending,
            as ``np.searchsorted`` requires.
        y: sequence of per-run 1-D value arrays (or lists), aligned element-wise with ``x``.
        stride: window width in steps.
        total_steps: last step considered. Only ``total_steps // stride`` full windows are used.
        confidence_level: confidence level passed to ``scipy.stats.bootstrap``.
        to_bootstrap: if False, no bootstrap is run and the interval arrays stay all zeros.
        show_progress: wrap the window loop in a ``tqdm`` progress bar (the default, as before).

    Returns:
        ``(steps, avg_ret, min_rets, max_rets, ci_low, ci_high)``, each of length
        ``total_steps // stride``. ``steps[i]`` is the right edge ``(i + 1) * stride`` of
        window ``i``. If some run has no samples in a window, that run's mean is NaN (NumPy
        warns) and the NaN propagates into the aggregates.

    Raises:
        AssertionError: if ``x`` and ``y`` hold a different number of runs.
    """
    assert len(x) == len(y)
    n_bins = total_steps // stride
    # Right edge of each window. Building it from n_bins guarantees the same length as the
    # per-window arrays; np.arange(stride, total_steps + stride, stride) produced one extra
    # element whenever stride did not divide total_steps.
    steps = stride * np.arange(1, n_bins + 1)
    avg_ret = np.zeros(n_bins)
    min_rets, max_rets = np.zeros(n_bins), np.zeros(n_bins)
    boot_strapped_ret_low, boot_strapped_ret_high = np.zeros(n_bins), np.zeros(n_bins)

    # Truncating each run to total_steps does not depend on the window, so do it once.
    # searchsorted(..., total_steps) + 1 keeps the first sample at or beyond total_steps,
    # matching the original slicing.
    truncated = []
    for run_x, run_y in zip(x, y):
        run_x = np.asarray(run_x)
        xa = run_x[:np.searchsorted(run_x, total_steps) + 1]
        truncated.append((xa, np.asarray(run_y)[:xa.shape[0]]))

    if to_bootstrap:
        # Import the submodule explicitly: a plain ``import scipy`` exposes ``scipy.stats``
        # only on SciPy versions with lazy submodule loading.
        from scipy import stats

    windows = range(n_bins)
    if show_progress:
        windows = tqdm(windows)
    for i in windows:
        lo, hi = i * stride, (i + 1) * stride
        # One mean per run over the samples in this window.
        rets = np.array([ya[(lo < xa) & (xa <= hi)].mean() for xa, ya in truncated])
        avg_ret[i] = rets.mean()
        min_rets[i], max_rets[i] = rets.min(), rets.max()
        if to_bootstrap:
            bos = stats.bootstrap(data=(rets,), statistic=np.mean,
                                  confidence_level=confidence_level)
            boot_strapped_ret_low[i] = bos.confidence_interval.low
            boot_strapped_ret_high[i] = bos.confidence_interval.high
    return steps, avg_ret, min_rets, max_rets, boot_strapped_ret_low, boot_strapped_ret_high
def get_param_performance(runs, data_dir=''):
    """Load each run's pickled results and collect its returns and termination steps.

    For each ``idx`` in ``runs`` the path ``data_dir + str(idx)`` is opened. If that path
    cannot be opened, ``data_dir + str(idx) + '.log'`` is tried instead. A path that begins
    with ``'d'`` gets a ``'../'`` prefix.

    Each file must unpickle to a mapping with ``'rets'`` and ``'termination_steps'`` keys.
    The opened file object and the last termination step are printed for every run.

    Args:
        runs: iterable of run identifiers, appended (via ``str``) to ``data_dir``.
        data_dir: string prefix for every path, concatenated as-is (no separator is added).

    Returns:
        ``(terminations, performances)``: two lists holding one NumPy array per run, in the
        order of ``runs``.

    Raises:
        OSError: if neither the bare path nor the ``.log`` path can be opened.
    """
    per_param_setting_performance, per_param_setting_termination = [], []
    for idx in runs:
        file = data_dir + str(idx)
        if file.startswith('d'):
            file = '../' + file
        # Only a failure to *open* triggers the '.log' fallback. The previous bare ``except:``
        # also swallowed unpickling errors and KeyboardInterrupt and hid them behind a
        # misleading FileNotFoundError for the '.log' path. Files are opened read-only;
        # 'rb+' needlessly required write permission.
        try:
            f = open(file, 'rb')
        except OSError:
            f = open(file + '.log', 'rb')
        with f:
            print(f)
            # SECURITY: pickle.load can execute arbitrary code; only load trusted result files.
            data = pickle.load(f)
        per_param_setting_performance.append(np.array(data['rets']))
        per_param_setting_termination.append(np.array(data['termination_steps']))
        print(data['termination_steps'][-1])
    return per_param_setting_termination, per_param_setting_performance
def plot_for_one_cfg(cfg, runs, m, ts, color='C0', min_max=False):
    """Draw one configuration's across-run mean curve and its bootstrap band.

    Loads the runs listed in ``runs`` from the prefix ``cfg['dir']`` and aggregates them into
    windows of ``m`` steps up to ``ts`` steps. It then plots the mean as a solid line labelled
    ``cfg['label']`` and shades the bootstrapped confidence interval (alpha 0.3). When
    ``min_max`` is true, the min-max range across runs is also shaded (alpha 0.1).
    Everything is drawn on the current pyplot axes; nothing is returned.
    """
    terms, rets = get_param_performance(data_dir=cfg['dir'], runs=runs)
    curves = bootstrapped_return(x=terms, y=rets, stride=m, total_steps=ts)
    steps, mean_ret, lowest, highest, ci_low, ci_high = curves
    plt.plot(steps, mean_ret, '-', linewidth=1, color=color, label=cfg['label'])
    plt.fill_between(steps, ci_low, ci_high, alpha=0.3, color=color)
    if min_max:
        plt.fill_between(steps, lowest, highest, alpha=0.1, color=color)
def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` maps every non-empty string, including ``"False"``, to
    True, so ``--all False`` could never turn the extra configurations off.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def main():
    """Plot the learning curves of several configurations for one environment.

    Command-line options:
        --env: environment name used in the config paths (default ``ant``).
        --all: if true (the default), the ``ns`` and ``l2`` configurations are plotted in
            addition to ``std`` and ``cbp``.

    Reads ``../cfg/{env}/<name>.yml`` for each configuration and plots one curve per config.
    The figure is saved as ``fig4a.png`` (dpi 250) with blank tick labels. The window width,
    total steps and y-ticks have environment-specific overrides for hopper, walker and ant.
    For any other env the y-axis is left to matplotlib's autoscaling.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', required=False, type=str, default='ant')
    parser.add_argument('--all', required=False, type=_str2bool, default=True)
    args = parser.parse_args()
    env = args.env
    plot_all = args.all

    cfg_files = [f'../cfg/{env}/std.yml', f'../cfg/{env}/cbp.yml']
    if plot_all:
        cfg_files += [f'../cfg/{env}/ns.yml', f'../cfg/{env}/l2.yml']
    # Colours are assigned positionally: std, cbp, ns, l2.
    colors = ['C3', 'C0', 'C1', 'C4']
    cfgs = []
    for file in cfg_files:
        with open(file) as fh:
            cfg = yaml.safe_load(fh)
        cfg.setdefault('label', '')
        cfgs.append(cfg)

    num_runs = 20  # an earlier setting used 30 runs
    runs = list(range(num_runs))
    m = 250 * 1000          # window width in steps
    ts = 100 * 1000 * 1000  # total steps plotted
    _, ax = plt.subplots()
    # Environment-specific overrides. None means "keep matplotlib's default y-axis".
    # Previously an unknown env crashed with a NameError here.
    yticks = None
    if env == 'hopper':
        yticks = [0, 1000, 2000, 2500]
        m = 500 * 1000
    elif env == 'walker':
        yticks = [0, 1000, 2000, 3000]
        ts = 50 * 1000 * 1000
    elif env == 'ant':
        yticks = [0, 2000, 4000, 5500]
        ts = 50 * 1000 * 1000

    for cfg, color in zip(cfgs, colors):
        plot_for_one_cfg(cfg=cfg, runs=runs, m=m, ts=ts, color=color)

    xticks = [0, 0.5 * ts, ts]
    fontsize = 15
    ax.set_xticks(xticks)
    ax.set_xticklabels(['' for _ in xticks], fontsize=fontsize)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    if yticks is not None:
        ax.set_yticks(yticks)
        ax.set_yticklabels(['' for _ in yticks], fontsize=fontsize)
        ax.set_ylim(yticks[0], yticks[-1])
    ax.yaxis.grid()
    plt.savefig('fig4a.png', bbox_inches='tight', dpi=250)
    plt.close()


if __name__ == "__main__":
    main()