# env.py
import os

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as TF
import PIL.Image as PilImage
from torchvision.utils import save_image
from dm_control import suite
from dm_control.suite.wrappers import pixels


CONTROL_SUITE_ENVS = ['cartpole-balance', 'cartpole-swingup', 'reacher-easy', 'reacher-hard',
                      'finger-spin', 'cheetah-run', 'ball_in_cup-catch', 'walker-walk']
CONTROL_SUITE_ACTION_REPEATS = {'cartpole': 8, 'reacher': 4, 'finger': 2, 'cheetah': 4, 'ball_in_cup': 6, 'walker': 2}
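
# Note (added for clarity, not in the original file): environment names follow the
# dm_control '<domain>-<task>' convention (e.g. 'walker-walk'), and the action repeat
# is looked up by domain alone, so every 'walker' task repeats each action 2 times.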


# Preprocesses an observation inplace (from float32 Tensor [0, 255] to [-0.5, 0.5])
def preprocess_observation_(observation, bit_depth):
  observation.div_(2 ** (8 - bit_depth)).floor_().div_(2 ** bit_depth).sub_(0.5)  # Quantise to given bit depth and centre
  observation.add_(torch.rand_like(observation).div_(2 ** bit_depth))  # Dequantise (to approx. match likelihood of PDF of continuous images vs. PMF of discrete images)


# Postprocesses an observation for storage (from float32 numpy array [-0.5, 0.5] to uint8 numpy array [0, 255])
def postprocess_observation(observation, bit_depth):
  return np.clip(np.floor((observation + 0.5) * 2 ** bit_depth) * 2 ** (8 - bit_depth), 0, 2 ** 8 - 1).astype(np.uint8)
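
# Worked example (illustrative, not in the original file; assumes bit_depth=5):
# an 8-bit pixel value of 200 is quantised to floor(200 / 2**3) = 25 of 2**5 levels,
# rescaled and centred to 25 / 32 - 0.5 = 0.28125, then jittered by uniform noise
# in [0, 1/32). postprocess_observation maps it back to
# floor((0.28125 + 0.5) * 32) * 2**3 = 200, so the round trip recovers the
# quantised value (an input of 203 would also come back as 200).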


def _images_to_observation(images, bit_depth):
  #images = TF.to_tensor(np.asarray(images.copy()))
  # Put channels first; .copy() avoids issues with the read-only render buffer
  images = torch.tensor(images.copy().transpose(2, 0, 1), dtype=torch.float32)
  preprocess_observation_(images, bit_depth)  # Quantise, centre and dequantise inplace
  return images.unsqueeze(dim=0)  # Add batch dimension -> (1, 3, H, W)


class ControlSuiteEnv():
  def __init__(self, env_name, seed, max_episode_length, bit_depth):
    domain, task = env_name.split('-')
    self._env = suite.load(domain_name=domain, task_name=task, task_kwargs={'random': seed})
    self._env = pixels.Wrapper(self._env)
    self.max_episode_length = max_episode_length
    self.action_repeat = CONTROL_SUITE_ACTION_REPEATS[domain]
    self.bit_depth = bit_depth

  def reset(self):
    self.t = 0  # Reset internal timer
    self._env.reset()
    return _images_to_observation(self._env.physics.render(height=64, width=64, camera_id=0), self.bit_depth)

  def step(self, action):
    action = action.detach().numpy()
    reward = 0
    for k in range(self.action_repeat):
      state = self._env.step(action)
      reward += state.reward
      self.t += 1  # Increment internal timer
      done = state.last() or self.t == self.max_episode_length
      if done:
        break
    observation = _images_to_observation(self._env.physics.render(height=64, width=64, camera_id=0), self.bit_depth)
    return observation, reward, done

  def get_original_frame(self):
    # Full-resolution render of the current frame (as opposed to the 64x64 observation)
    observation = _images_to_observation(self._env.physics.render(camera_id=0), self.bit_depth)
    return observation

  def render(self):
    cv2.imshow('screen', self._env.physics.render(camera_id=0)[:, :, ::-1])  # RGB -> BGR for OpenCV
    cv2.waitKey(1)

  def close(self):
    cv2.destroyAllWindows()
    self._env.close()

  @property
  def observation_size(self):
    return (3, 64, 64)

  @property
  def action_size(self):
    return self._env.action_spec().shape[0]

  @property
  def action_range(self):
    return float(self._env.action_spec().minimum[0]), float(self._env.action_spec().maximum[0])

  # Sample an action randomly from a uniform distribution over all valid actions (return torch tensor)
  def sample_random_action(self):
    spec = self._env.action_spec()
    return torch.from_numpy(np.random.uniform(spec.minimum, spec.maximum, spec.shape))


class EnvBatcher():
  def __init__(self, env_class, env_args, env_kwargs, n):
    self.n = n
    self.envs = [env_class(*env_args, **env_kwargs) for _ in range(n)]
    self.dones = [True] * n
    self.action_repeat = self.envs[0].action_repeat

  # Resets every environment and returns observation
  def reset(self):
    observations = [env.reset() for env in self.envs]
    self.dones = [False] * self.n
    return torch.cat(observations)

  # Steps/resets every environment and returns (observation, reward, done)
  def step(self, actions):
    done_mask = torch.nonzero(torch.tensor(self.dones))[:, 0]  # Done mask to blank out observations and zero rewards for previously terminated environments
    observations, rewards, dones = zip(*[env.step(action) for env, action in zip(self.envs, actions)])
    dones = [d or prev_d for d, prev_d in zip(dones, self.dones)]  # Env should remain terminated if previously terminated
    self.dones = dones
    observations, rewards, dones = torch.cat(observations), torch.tensor(rewards, dtype=torch.float32), torch.tensor(dones, dtype=torch.uint8)
    observations[done_mask] = 0
    rewards[done_mask] = 0
    return observations, rewards, dones

  def close(self):
    for env in self.envs:
      env.close()
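

# Minimal usage sketch (not part of the original module): roll out one episode of
# a control suite task with uniformly random actions. It assumes a working
# dm_control / MuJoCo installation; the task name and hyperparameters below are
# illustrative only.
if __name__ == '__main__':
  env = ControlSuiteEnv('cartpole-balance', seed=0, max_episode_length=1000, bit_depth=5)
  observation, done, total_reward = env.reset(), False, 0.0
  while not done:
    action = env.sample_random_action()           # Uniform sample within the action spec
    observation, reward, done = env.step(action)  # Observation is a (1, 3, 64, 64) tensor in [-0.5, 0.5]
    total_reward += reward
  print('Random-policy return: {:.2f}'.format(total_reward))
  env.close()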