-
Notifications
You must be signed in to change notification settings - Fork 50
/
Copy pathfig3.py
108 lines (88 loc) · 3.84 KB
/
fig3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import yaml
import scipy
import pickle
import argparse
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
def bootstrapped_return(x, y, stride, total_steps, confidence_level=0.95, to_bootstrap=True):
    """Aggregate per-run returns into fixed-width step bins and summarise them across runs.

    Bin ``i`` covers the step interval ``(i * stride, (i + 1) * stride]``. For every run,
    the entries of ``y[run]`` whose matching ``x[run]`` falls inside the bin are averaged;
    the per-run means are then reduced across runs to a mean, a minimum, a maximum and,
    optionally, a bootstrap confidence interval of the mean.

    A run with no entry inside a bin contributes ``nan`` for that bin (NumPy's mean of an
    empty slice), and that ``nan`` propagates to the across-run statistics.

    Args:
        x: Sequence with one 1-D array per run giving the step of each recorded value.
            Assumed to be sorted in ascending order, as ``np.searchsorted`` requires.
        y: Sequence with one 1-D array per run, aligned element-wise with ``x``.
        stride: Bin width in steps (positive integer).
        total_steps: Horizon in steps; ``total_steps // stride`` bins are produced.
        confidence_level: Confidence level handed to ``scipy.stats.bootstrap``.
        to_bootstrap: When false the bootstrap is skipped and both interval arrays
            are returned filled with zeros.

    Returns:
        Tuple ``(steps, avg_ret, min_rets, max_rets, ci_low, ci_high)`` of arrays with one
        entry per bin; ``steps[i]`` is the right edge ``(i + 1) * stride`` of bin ``i``.

    Raises:
        ValueError: If ``x`` and ``y`` do not contain the same number of runs.
    """
    # Imported explicitly: a bare ``import scipy`` does not guarantee that the
    # ``scipy.stats`` submodule has been loaded.
    from scipy.stats import bootstrap

    if len(x) != len(y):
        raise ValueError(f'x and y must hold the same number of runs, got {len(x)} and {len(y)}')
    num_runs = len(x)
    num_bins = total_steps // stride

    # Right bin edges derived from num_bins so that `steps` always has the same length
    # as the statistics, even when total_steps is not a multiple of stride.
    steps = stride * np.arange(1, num_bins + 1)
    avg_ret = np.zeros(num_bins)
    min_rets, max_rets = np.zeros(num_bins), np.zeros(num_bins)
    boot_strapped_ret_low, boot_strapped_ret_high = np.zeros(num_bins), np.zeros(num_bins)

    # The truncation of each run to the horizon does not depend on the bin, so do it once.
    # The +1 keeps the first entry located at or beyond total_steps.
    truncated_runs = []
    for run in range(num_runs):
        xa = np.asarray(x[run])
        xa = xa[:np.searchsorted(xa, total_steps) + 1]
        ya = np.asarray(y[run])[:xa.shape[0]]
        truncated_runs.append((xa, ya))

    for i in tqdm(range(num_bins)):
        lower, upper = i * stride, (i + 1) * stride
        rets = np.array([
            ya[np.logical_and(lower < xa, xa <= upper)].mean()
            for xa, ya in truncated_runs
        ])
        avg_ret[i] = rets.mean()
        min_rets[i], max_rets[i] = rets.min(), rets.max()
        if to_bootstrap:
            bos = bootstrap(data=(rets,), statistic=np.mean, confidence_level=confidence_level)
            boot_strapped_ret_low[i] = bos.confidence_interval.low
            boot_strapped_ret_high[i] = bos.confidence_interval.high
    return steps, avg_ret, min_rets, max_rets, boot_strapped_ret_low, boot_strapped_ret_high
def get_param_performance(runs, data_dir=''):
    """Load the pickled result of every run and collect its returns and termination steps.

    For each ``idx`` in ``runs`` the path ``data_dir + str(idx)`` is unpickled. If that
    file cannot be opened or is not a readable pickle, the same path with a ``.log``
    suffix is tried instead. A path starting with the letter ``'d'`` is first prefixed
    with ``'../'``.

    Args:
        runs: Iterable of run identifiers; each one is converted with ``str`` and
            appended to ``data_dir``.
        data_dir: String prefix (normally a directory ending in a path separator)
            placed in front of every run identifier.

    Returns:
        Tuple ``(terminations, returns)`` of two lists holding one NumPy array per run,
        built from the ``'termination_steps'`` and ``'rets'`` entries of each pickle.

    Raises:
        OSError: If the ``.log`` fallback cannot be opened either.
        KeyError: If a loaded object lacks ``'rets'`` or ``'termination_steps'``.
    """
    per_param_setting_performance, per_param_setting_termination = [], []
    for idx in runs:
        file = data_dir + str(idx)
        if file.startswith('d'):
            file = '../' + file
        # NOTE: pickle.load can execute arbitrary code; only use on trusted result files.
        # The files are only read, so open them read-only ('rb') rather than 'rb+',
        # which would fail on write-protected results.
        try:
            with open(file, 'rb') as f:
                print(f)
                data = pickle.load(f)
        except (OSError, pickle.UnpicklingError, EOFError):
            # Fall back to the '.log' variant; unlike a bare except this lets
            # KeyboardInterrupt and SystemExit propagate.
            with open(file + '.log', 'rb') as f:
                print(f)
                data = pickle.load(f)
        per_param_setting_performance.append(np.array(data['rets']))
        per_param_setting_termination.append(np.array(data['termination_steps']))
        # Progress/debug output: the last recorded termination step of this run.
        print(data['termination_steps'][-1])
    return per_param_setting_termination, per_param_setting_performance
def plot_for_one_cfg(cfg, runs, m, ts, color='C0', min_max=False):
    """Draw the binned mean-return curve of one configuration on the current axes.

    The runs are read from ``cfg['dir']``, binned with stride ``m`` up to ``ts`` steps,
    and plotted as a line labelled ``cfg['label']`` with a shaded bootstrap interval.
    With ``min_max`` set, a fainter band between the per-bin minimum and maximum
    across runs is added as well.
    """
    step_traces, return_traces = get_param_performance(runs=runs, data_dir=cfg['dir'])
    summary = bootstrapped_return(x=step_traces, y=return_traces, stride=m, total_steps=ts)
    steps, mean_ret, lowest, highest, ci_low, ci_high = summary

    plt.plot(steps, mean_ret, '-', linewidth=1, color=color, label=cfg['label'])

    # (lower edge, upper edge, opacity) of every shaded band, in drawing order.
    bands = [(ci_low, ci_high, 0.3)]
    if min_max:
        bands.append((lowest, highest, 0.1))
    for lower, upper, opacity in bands:
        plt.fill_between(steps, lower, upper, alpha=opacity, color=color)
def main():
    """Plot the learning curves of the four 'sant' configurations and save them to fig3.png.

    Each YAML config (std, ns, cbp, l2) must provide a ``dir`` entry pointing at its run
    files; a missing ``label`` defaults to the empty string. Runs 0..99 are binned in
    strides of 100k steps over 20M steps. Tick positions are set but their labels are
    deliberately blank, and only horizontal grid lines are drawn.
    """
    env = 'sant'
    cfg_files = [f'../cfg/{env}/{name}.yml' for name in ('std', 'ns', 'cbp', 'l2')]
    colors = ['C3', 'C1', 'C0', 'C4']

    # Keep every colour paired with its config file so that blanking out one entry
    # does not shift the colours of the remaining curves.
    cfgs, cfg_colors = [], []
    for file, color in zip(cfg_files, colors):
        if file == '':
            continue
        # Context manager so the config file handle is closed after parsing.
        with open(file) as f:
            cfg = yaml.safe_load(f)
        if 'label' not in cfg:
            cfg['label'] = ''
        cfgs.append(cfg)
        cfg_colors.append(color)

    num_runs = 100
    runs = list(range(num_runs))
    m = 100 * 1000
    ts = 20 * 1000 * 1000

    _fig, ax = plt.subplots()
    yticks = [0, 2000, 4000, 5500]
    for cfg, color in zip(cfgs, cfg_colors):
        plot_for_one_cfg(cfg=cfg, runs=runs, m=m, ts=ts, color=color)

    xticks = [0, 0.5 * ts, ts]
    fontsize = 15
    ax.set_xticks(xticks)
    ax.set_xticklabels(['' for _ in xticks], fontsize=fontsize)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_yticks(yticks)
    ax.set_yticklabels(['' for _ in yticks], fontsize=fontsize)
    ax.set_ylim(yticks[0], yticks[-1])
    ax.yaxis.grid()
    plt.savefig('fig3.png', bbox_inches='tight', dpi=300)
    plt.close()
# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()