-
Notifications
You must be signed in to change notification settings - Fork 6
/
StateBuffer.py
93 lines (79 loc) · 3.04 KB
/
StateBuffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import datetime
import gym
import numpy as np
import torch
from tensorboardX import SummaryWriter
from ImageWrapper import ImageWrapper
class StateBuffer:
    """Fixed-capacity ring buffer holding the most recent environment states.

    Every slot is pre-filled with ``initial_state`` (the same object, not
    copies) so all read methods work before the first ``push``. ``push``
    overwrites the oldest slot, and the read methods return states in
    chronological order (oldest first).

    Args:
        capacity: number of states kept; must be at least 1.
        initial_state: array with at least three dimensions. Its first three
            sizes are recorded as ``c``, ``h`` and ``w`` (presumably a
            channels-first image -- TODO confirm against ImageWrapper).

    Raises:
        ValueError: if ``capacity`` is smaller than 1.
    """

    def __init__(self, capacity, initial_state):
        if capacity < 1:
            raise ValueError(f"capacity must be >= 1, got {capacity}")
        # Read strategy used by callers; _get_mean / _get_diff are alternative
        # readers that could be assigned here instead.
        self.get_state = self._get_all
        self.capacity = capacity
        self.c = initial_state.shape[0]
        self.h = initial_state.shape[1]
        self.w = initial_state.shape[2]
        # All slots reference the initial state until real states are pushed.
        self.buffer = [initial_state] * capacity
        # Slot the next push overwrites, which is also the oldest stored state.
        self.position = 0

    def push(self, state_):
        """Store ``state_`` in place of the oldest state."""
        self.buffer[self.position] = state_
        self.position = (self.position + 1) % self.capacity

    def _ordered_frames(self):
        """Return the stored states as a new list, oldest first."""
        start = self.position
        return self.buffer[start:] + self.buffer[:start]

    def get_tensor(self):
        """Return all states stacked into one float32 tensor, oldest first.

        The result has shape ``(capacity, *state.shape)``.
        """
        # Stack once instead of growing the tensor with repeated torch.cat,
        # which re-copied the partial result on every step.
        return torch.stack(
            [torch.tensor(frame, dtype=torch.float32) for frame in self._ordered_frames()]
        )

    def _get_all(self):
        """Concatenate all states along their first axis, oldest first."""
        return np.concatenate(self._ordered_frames())

    def _get_mean(self):
        """Return the element-wise mean of all stored states.

        The stored states are never modified. ``np.mean`` accumulates integer
        input in float64, so e.g. uint8 frames cannot overflow.
        """
        return np.mean(np.stack(self._ordered_frames()), axis=0)

    def _get_diff(self):
        """Return the newest state minus the state pushed just before it.

        With ``capacity == 1`` both operands are the same slot, so the result
        is all zeros.

        NOTE: unsigned integer frames wrap around on subtraction; cast them to
        a signed or floating dtype first if negative differences matter.
        """
        newest = self.buffer[(self.position - 1) % self.capacity]
        previous = self.buffer[(self.position - 2) % self.capacity]
        return newest - previous

    def __len__(self):
        """Return the number of slots, which always equals ``capacity``."""
        return len(self.buffer)
if __name__ == '__main__':
    # Demo: log a random-action rollout to TensorBoard as stacks of recent frames.
    env_id = "Pendulum-v0"
    folder = '{}_StateBuffer_{}/'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), env_id)
    writer = SummaryWriter(log_dir='runs/' + folder)
    # ImageWrapper is project code; presumably it turns observations into
    # image frames, with 512 as a size parameter -- TODO confirm.
    env = ImageWrapper(512, gym.make(env_id))
    try:
        state_buffer = StateBuffer(3, env.reset())
        for step in range(199):
            # add_images defaults to NCHW; get_tensor yields (capacity, *frame.shape),
            # so frames are presumably (C, H, W).
            writer.add_images('episode', state_buffer.get_tensor(), step)
            next_state, reward, done, _ = env.step(env.action_space.sample())
            state_buffer.push(next_state)
            if done:
                break
    finally:
        # Release the environment and flush pending TensorBoard events even on error.
        env.close()
        writer.close()