forked from thu-ml/tianshou
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path mujoco_env.py
117 lines (99 loc) · 3.77 KB
/
mujoco_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import logging
import pickle
from gymnasium import Env
from tianshou.env import BaseVectorEnv, VectorEnvNormObs
from tianshou.highlevel.env import (
ContinuousEnvironments,
EnvFactoryRegistered,
EnvMode,
EnvPoolFactory,
VectorEnvType,
)
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
from tianshou.highlevel.world import World
# EnvPool is an optional dependency: record whether it could be imported so the
# factory below can decide whether to hand an EnvPoolFactory to its base class.
try:
    import envpool
except ImportError:
    envpool_is_available = False
    envpool = None
else:
    envpool_is_available = True

# module-level logger
log = logging.getLogger(__name__)
def make_mujoco_env(
    task: str,
    seed: int,
    num_train_envs: int,
    num_test_envs: int,
    obs_norm: bool,
) -> tuple[Env, BaseVectorEnv, BaseVectorEnv]:
    """Wrapper function for Mujoco env.

    Builds the environments through :class:`MujocoEnvFactory`, which switches to
    EnvPool's Mujoco env automatically when EnvPool is installed.

    :param task: the name of the task to create
    :param seed: passed as the factory's training seed; the test seed is
        ``seed + num_train_envs``
    :param num_train_envs: the number of training environments
    :param num_test_envs: the number of test environments
    :param obs_norm: whether to apply observation normalization
    :return: a tuple of (single env, training envs, test envs).
    """
    test_seed = seed + num_train_envs
    factory = MujocoEnvFactory(task, seed, test_seed, obs_norm=obs_norm)
    environments = factory.create_envs(num_train_envs, num_test_envs)
    return environments.env, environments.train_envs, environments.test_envs
class MujocoEnvObsRmsPersistence(Persistence):
    """Saves and restores the environments' observation-normalization statistics (``obs_rms``).

    On a policy-persist event, the statistics are read from the training envs via
    ``get_obs_rms()`` and pickled to :attr:`FILENAME`. On a policy-restore event,
    they are unpickled and installed into the training, test and (if present)
    watch envs via ``set_obs_rms()``.
    """

    # file name, resolved through World.persist_path / World.restore_path
    FILENAME = "env_obs_rms.pkl"

    def persist(self, event: PersistEvent, world: World) -> None:
        """Pickle the training envs' ``obs_rms`` when the policy is persisted.

        :param event: the persistence event; anything other than PERSIST_POLICY is ignored
        :param world: provides the environments and the target path
        """
        if event != PersistEvent.PERSIST_POLICY:
            return  # type: ignore[unreachable] # since PersistEvent has only one member, mypy infers that line is unreachable
        obs_rms = world.envs.train_envs.get_obs_rms()
        path = world.persist_path(self.FILENAME)
        # lazy %-style arguments: the message is only formatted if INFO is enabled
        log.info("Saving environment obs_rms value to %s", path)
        with open(path, "wb") as f:
            pickle.dump(obs_rms, f)

    def restore(self, event: RestoreEvent, world: World) -> None:
        """Load a previously pickled ``obs_rms`` and install it into all environments.

        :param event: the restore event; anything other than RESTORE_POLICY is ignored
        :param world: provides the environments and the source path
        """
        if event != RestoreEvent.RESTORE_POLICY:
            return  # type: ignore[unreachable]
        path = world.restore_path(self.FILENAME)
        log.info("Restoring environment obs_rms value from %s", path)
        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only restore from trusted directories, e.g. checkpoints written by persist().
        with open(path, "rb") as f:
            obs_rms = pickle.load(f)
        envs = world.envs
        envs.train_envs.set_obs_rms(obs_rms)
        envs.test_envs.set_obs_rms(obs_rms)
        if envs.watch_env is not None:
            envs.watch_env.set_obs_rms(obs_rms)
class MujocoEnvFactory(EnvFactoryRegistered):
    """Factory for registered Mujoco tasks with optional observation normalization.

    When EnvPool could be imported, an :class:`EnvPoolFactory` is handed to the
    base factory. When ``obs_norm`` is enabled, every vectorized env is wrapped in
    :class:`VectorEnvNormObs`, and the resulting environments get a
    :class:`MujocoEnvObsRmsPersistence` so the statistics are saved and restored
    with the policy.
    """

    def __init__(
        self,
        task: str,
        train_seed: int,
        test_seed: int,
        obs_norm: bool = True,
        venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
    ) -> None:
        """
        :param task: the name of the registered task
        :param train_seed: the seed passed to the base factory for the training envs
        :param test_seed: the seed passed to the base factory for the test envs
        :param obs_norm: whether to wrap vectorized envs in VectorEnvNormObs
        :param venv_type: the type of vectorized environment to create
        """
        super().__init__(
            task=task,
            train_seed=train_seed,
            test_seed=test_seed,
            venv_type=venv_type,
            # use EnvPool only if the optional import at module level succeeded
            envpool_factory=EnvPoolFactory() if envpool_is_available else None,
        )
        self.obs_norm = obs_norm

    def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
        """Create vectorized environments.

        :param num_envs: the number of environments
        :param mode: the mode for which to create
        :return: the vectorized environments, wrapped in VectorEnvNormObs if ``obs_norm`` is set
        """
        env = super().create_venv(num_envs, mode)
        if self.obs_norm:
            # only TRAIN-mode envs update the normalization statistics
            env = VectorEnvNormObs(env, update_obs_rms=mode == EnvMode.TRAIN)
        return env

    def create_envs(
        self,
        num_training_envs: int,
        num_test_envs: int,
        create_watch_env: bool = False,
    ) -> ContinuousEnvironments:
        """Create the training, test and (optionally) watch environments.

        With ``obs_norm`` enabled, the test and watch envs receive the training envs'
        ``obs_rms``, and a persistence handler for it is registered.

        :param num_training_envs: the number of training environments
        :param num_test_envs: the number of test environments
        :param create_watch_env: whether to also create a watch environment
        :return: the created continuous environments
        """
        envs = super().create_envs(num_training_envs, num_test_envs, create_watch_env)
        assert isinstance(envs, ContinuousEnvironments)
        if self.obs_norm:
            # fetch once and hand the same value to every non-training env
            # NOTE(review): presumably get_obs_rms() returns the live statistics object, so
            # test/watch envs track training updates — confirm in VectorEnvNormObs
            obs_rms = envs.train_envs.get_obs_rms()
            envs.test_envs.set_obs_rms(obs_rms)
            if envs.watch_env is not None:
                envs.watch_env.set_obs_rms(obs_rms)
            # the persistence handler relies on get_obs_rms(), which only exists when obs_norm is on
            envs.set_persistence(MujocoEnvObsRmsPersistence())
        return envs