From 6eb35003cd7544095afb533d6e947e00895b0b56 Mon Sep 17 00:00:00 2001 From: bordeauxred <2robert.mueller@gmail.com> Date: Thu, 28 Mar 2024 18:02:31 +0100 Subject: [PATCH] Feat/refactor collector (#1063) Closes: #1058 ### Api Extensions - Batch received two new methods: `to_dict` and `to_list_of_dicts`. #1063 - `Collector`s can now be closed, and their reset is more granular. #1063 - Trainers can control whether collectors should be reset prior to training. #1063 - Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063 ### Internal Improvements - `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063 - Introduced a first iteration of a naming convention for vars in `Collector`s. #1063 - Generally improved readability of Collector code and associated tests (still quite some way to go). #1063 - Improved typing for `exploration_noise` and within Collector. #1063 ### Breaking Changes - Removed `.data` attribute from `Collector` and its child classes. #1063 - Collectors no longer reset the environment on initialization. Instead, the user might have to call `reset` expicitly or pass `reset_before_collect=True` . #1063 - VectorEnvs now return an array of info-dicts on reset instead of a list. #1063 - Fixed `iter(Batch(...)` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063 --------- Co-authored-by: Michael Panchenko --- CHANGELOG.md | 23 + docs/02_notebooks/L0_overview.ipynb | 2 +- docs/02_notebooks/L5_Collector.ipynb | 5 +- pyproject.toml | 1 + test/base/env.py | 24 +- test/base/test_buffer.py | 12 +- test/base/test_collector.py | 513 +++++++----- test/base/test_env.py | 22 +- test/base/test_env_finite.py | 19 +- test/continuous/test_redq.py | 1 + test/continuous/test_td3.py | 1 + test/discrete/test_a2c_with_il.py | 2 + test/discrete/test_bdq.py | 2 +- test/discrete/test_c51.py | 2 +- test/discrete/test_dqn.py | 2 +- test/discrete/test_drqn.py | 2 +- test/discrete/test_fqf.py | 2 +- test/discrete/test_iqn.py | 2 +- test/discrete/test_qrdqn.py | 2 +- test/discrete/test_rainbow.py | 2 +- test/modelbased/test_dqn_icm.py | 2 +- test/modelbased/test_psrl.py | 3 +- test/offline/gather_cartpole_data.py | 5 +- test/offline/test_discrete_bcq.py | 1 + test/pettingzoo/pistonball.py | 6 +- test/pettingzoo/pistonball_continuous.py | 2 +- test/pettingzoo/tic_tac_toe.py | 6 +- tianshou/data/batch.py | 34 +- tianshou/data/buffer/manager.py | 2 +- tianshou/data/collector.py | 923 ++++++++++++++-------- tianshou/data/utils/converter.py | 1 + tianshou/env/venv_wrappers.py | 16 +- tianshou/env/venvs.py | 26 +- tianshou/highlevel/agent.py | 11 + tianshou/highlevel/experiment.py | 2 +- tianshou/policy/base.py | 12 +- tianshou/policy/modelbased/icm.py | 10 +- tianshou/policy/modelfree/bdq.py | 9 +- tianshou/policy/modelfree/ddpg.py | 9 +- tianshou/policy/modelfree/discrete_sac.py | 11 +- tianshou/policy/modelfree/dqn.py | 9 +- tianshou/policy/multiagent/mapolicy.py | 22 +- tianshou/trainer/base.py | 28 +- tianshou/trainer/utils.py | 3 +- 44 files changed, 1152 insertions(+), 642 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e0ac65c5..5a37acb17 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,27 @@ # Changelog +## Release 1.1.0 + +### Api Extensions +- Batch received two new methods: `to_dict` and `to_list_of_dicts`. #1063 +- `Collector`s can now be closed, and their reset is more granular. 
#1063 +- Trainers can control whether collectors should be reset prior to training. #1063 +- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063 + +### Internal Improvements +- `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063 +- Introduced a first iteration of a naming convention for vars in `Collector`s. #1063 +- Generally improved readability of Collector code and associated tests (still quite some way to go). #1063 +- Improved typing for `exploration_noise` and within Collector. #1063 + +### Breaking Changes + +- Removed `.data` attribute from `Collector` and its child classes. #1063 +- Collectors no longer reset the environment on initialization. Instead, the user might have to call `reset` +expicitly or pass `reset_before_collect=True` . #1063 +- VectorEnvs now return an array of info-dicts on reset instead of a list. #1063 +- Fixed `iter(Batch(...)` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063 + + Started after v1.0.0 diff --git a/docs/02_notebooks/L0_overview.ipynb b/docs/02_notebooks/L0_overview.ipynb index ac0514b82..37cba0be5 100644 --- a/docs/02_notebooks/L0_overview.ipynb +++ b/docs/02_notebooks/L0_overview.ipynb @@ -164,7 +164,7 @@ "source": [ "# Let's watch its performance!\n", "policy.eval()\n", - "eval_result = test_collector.collect(n_episode=1, render=False)\n", + "eval_result = test_collector.collect(n_episode=3, render=False)\n", "print(f\"Final reward: {eval_result.returns.mean()}, length: {eval_result.lens.mean()}\")" ] }, diff --git a/docs/02_notebooks/L5_Collector.ipynb b/docs/02_notebooks/L5_Collector.ipynb index 1053e15a5..3e91e0f43 100644 --- a/docs/02_notebooks/L5_Collector.ipynb +++ b/docs/02_notebooks/L5_Collector.ipynb @@ -119,7 +119,7 @@ }, "outputs": [], "source": [ - "collect_result = test_collector.collect(n_episode=9)\n", + "collect_result = test_collector.collect(reset_before_collect=True, n_episode=9)\n", "\n", "collect_result.pprint_asdict()" ] @@ -146,8 +146,7 @@ "outputs": [], "source": [ "# Reset the collector\n", - "test_collector.reset()\n", - "collect_result = test_collector.collect(n_episode=9, random=True)\n", + "collect_result = test_collector.collect(reset_before_collect=True, n_episode=9, random=True)\n", "\n", "collect_result.pprint_asdict()" ] diff --git a/pyproject.toml b/pyproject.toml index d4679f949..7a795004a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -166,6 +166,7 @@ ignore = [ "RET505", "D106", # undocumented public nested class "D205", # blank line after summary (prevents summary-only docstrings, which makes no sense) + "PLW2901", # overwrite vars in loop ] unfixable = [ "F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all diff --git a/test/base/env.py b/test/base/env.py index 8a2de26cc..c05c98718 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -9,13 +9,24 @@ from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Space, Tuple -class MyTestEnv(gym.Env): - """A task for "going right". The task is to go right ``size`` steps.""" +class MoveToRightEnv(gym.Env): + """A task for "going right". The task is to go right ``size`` steps. + + The observation is the current index, and the action is to go left or right. + Action 0 is to go left, and action 1 is to go right. + Taking action 0 at index 0 will keep the index at 0. + Arriving at index ``size`` means the task is done. 
+ In the current implementation, stepping after the task is done is possible, which will + lead the index to be larger than ``size``. + + Index 0 is the starting point. If reset is called with default options, the index will + be reset to 0. + """ def __init__( self, size: int, - sleep: int = 0, + sleep: float = 0.0, dict_state: bool = False, recurse_state: bool = False, ma_rew: int = 0, @@ -74,8 +85,13 @@ def __init__( def reset( self, seed: int | None = None, + # TODO: passing a dict here doesn't make any sense options: dict[str, Any] | None = None, ) -> tuple[dict[str, Any] | np.ndarray, dict]: + """:param seed: + :param options: the start index is provided in options["state"] + :return: + """ if options is None: options = {"state": 0} super().reset(seed=seed) @@ -188,7 +204,7 @@ def step( return self._encode_obs(), 1.0, False, False, {} -class MyGoalEnv(MyTestEnv): +class MyGoalEnv(MoveToRightEnv): def __init__(self, *args: Any, **kwargs: Any) -> None: assert ( kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0 diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 99154bbdf..0806a750f 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -22,13 +22,13 @@ from tianshou.data.utils.converter import to_hdf5 if __name__ == "__main__": - from env import MyGoalEnv, MyTestEnv + from env import MoveToRightEnv, MyGoalEnv else: # pytest - from test.base.env import MyGoalEnv, MyTestEnv + from test.base.env import MoveToRightEnv, MyGoalEnv def test_replaybuffer(size=10, bufsize=20) -> None: - env = MyTestEnv(size) + env = MoveToRightEnv(size) buf = ReplayBuffer(bufsize) buf.update(buf) assert str(buf) == buf.__class__.__name__ + "()" @@ -209,7 +209,7 @@ def test_ignore_obs_next(size=10) -> None: def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None: - env = MyTestEnv(size) + env = MoveToRightEnv(size) buf = ReplayBuffer(bufsize, stack_num=stack_num) buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True) @@ -280,7 +280,7 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None: def test_priortized_replaybuffer(size=32, bufsize=15) -> None: - env = MyTestEnv(size) + env = MoveToRightEnv(size) buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5) buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5) obs, info = env.reset() @@ -1028,7 +1028,7 @@ def test_multibuf_stack() -> None: bufsize = 9 stack_num = 4 cached_num = 3 - env = MyTestEnv(size) + env = MoveToRightEnv(size) # test if CachedReplayBuffer can handle stack_num + ignore_obs_next buf4 = CachedReplayBuffer( ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True), diff --git a/test/base/test_collector.py b/test/base/test_collector.py index f7a24a86e..6bc1703f6 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -2,7 +2,6 @@ import numpy as np import pytest import tqdm -from torch.utils.tensorboard import SummaryWriter from tianshou.data import ( AsyncCollector, @@ -22,12 +21,12 @@ envpool = None if __name__ == "__main__": - from env import MyTestEnv, NXEnv + from env import MoveToRightEnv, NXEnv else: # pytest - from test.base.env import MyTestEnv, NXEnv + from test.base.env import MoveToRightEnv, NXEnv -class MyPolicy(BasePolicy): +class MaxActionPolicy(BasePolicy): def __init__( self, action_space: gym.spaces.Space | None = None, @@ -35,7 +34,9 @@ def __init__( need_state=True, action_shape=None, ) -> None: - """Mock 
policy for testing. + """Mock policy for testing, will always return an array of ones of the shape of the action space. + Note that this doesn't make much sense for discrete action space (the output is then intepreted as + logits, meaning all actions would be equally likely). :param action_space: the action space of the environment. If None, a dummy Box space will be used. :param bool dict_state: if the observation of the environment is a dict @@ -63,215 +64,290 @@ def learn(self): pass -class Logger: - def __init__(self, writer) -> None: - self.cnt = 0 - self.writer = writer - - def preprocess_fn(self, **kwargs): - # modify info before adding into the buffer, and recorded into tfb - # if obs && env_id exist -> reset - # if obs_next/rew/done/info/env_id exist -> normal step - if "rew" in kwargs: - info = kwargs["info"] - info.rew = kwargs["rew"] - if "key" in info: - self.writer.add_scalar("key", np.mean(info.key), global_step=self.cnt) - self.cnt += 1 - return Batch(info=info) - return Batch() - - @staticmethod - def single_preprocess_fn(**kwargs): - # same as above, without tfb - if "rew" in kwargs: - info = kwargs["info"] - info.rew = kwargs["rew"] - return Batch(info=info) - return Batch() - - -@pytest.mark.parametrize("gym_reset_kwargs", [None, {}]) -def test_collector(gym_reset_kwargs) -> None: - writer = SummaryWriter("log/collector") - logger = Logger(writer) - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]] - - venv = SubprocVectorEnv(env_fns) - dum = DummyVectorEnv(env_fns) - policy = MyPolicy() - env = env_fns[0]() - c0 = Collector( +def test_collector() -> None: + env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [2, 3, 4, 5]] + + subproc_venv_4_envs = SubprocVectorEnv(env_fns) + dummy_venv_4_envs = DummyVectorEnv(env_fns) + policy = MaxActionPolicy() + single_env = env_fns[0]() + c_single_env = Collector( policy, - env, + single_env, ReplayBuffer(size=100), - logger.preprocess_fn, ) - c0.collect(n_step=3, gym_reset_kwargs=gym_reset_kwargs) - assert len(c0.buffer) == 3 - assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0]) - assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1]) + c_single_env.reset() + c_single_env.collect(n_step=3) + assert len(c_single_env.buffer) == 3 + # TODO: direct attr access is an arcane way of using the buffer, it should be never done + # The placeholders for entries are all zeros, so buffer.obs is an array filled with 3 + # observations, and 97 zeros. + # However, buffer[:] will have all attributes with length three... The non-filled entries are removed there + + # See above. For the single env, we start with obs=0, obs_next=1. 
+ # We move to obs=1, obs_next=2, + # then the env is reset and we move to obs=0 + # Making one more step results in obs_next=1 + # The final 0 in the buffer.obs is because the buffer is initialized with zeros and the direct attr access + assert np.allclose(c_single_env.buffer.obs[:4, 0], [0, 1, 0, 0]) + assert np.allclose(c_single_env.buffer[:].obs_next[..., 0], [1, 2, 1]) keys = np.zeros(100) keys[:3] = 1 - assert np.allclose(c0.buffer.info["key"], keys) - for e in c0.buffer.info["env"][:3]: - assert isinstance(e, MyTestEnv) - assert np.allclose(c0.buffer.info["env_id"], 0) + assert np.allclose(c_single_env.buffer.info["key"], keys) + for e in c_single_env.buffer.info["env"][:3]: + assert isinstance(e, MoveToRightEnv) + assert np.allclose(c_single_env.buffer.info["env_id"], 0) rews = np.zeros(100) rews[:3] = [0, 1, 0] - assert np.allclose(c0.buffer.info["rew"], rews) - c0.collect(n_episode=3, gym_reset_kwargs=gym_reset_kwargs) - assert len(c0.buffer) == 8 - assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0]) - assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) - assert np.allclose(c0.buffer.info["key"][:8], 1) - for e in c0.buffer.info["env"][:8]: - assert isinstance(e, MyTestEnv) - assert np.allclose(c0.buffer.info["env_id"][:8], 0) - assert np.allclose(c0.buffer.info["rew"][:8], [0, 1, 0, 1, 0, 1, 0, 1]) - c0.collect(n_step=3, random=True, gym_reset_kwargs=gym_reset_kwargs) - - c1 = Collector( + assert np.allclose(c_single_env.buffer.rew, rews) + # At this point, the buffer contains obs 0 -> 1 -> 0 + + # At start we have 3 entries in the buffer + # We collect 3 episodes, in addition to the transitions we have collected before + # 0 -> 1 -> 0 -> 0 (reset at collection start) -> 1 -> done (0) -> 1 -> done(0) + # obs_next: 1 -> 2 -> 1 -> 1 (reset at collection start) -> 2 -> 1 -> 2 -> 1 -> 2 + # In total, we will have 3 + 6 = 9 entries in the buffer + c_single_env.collect(n_episode=3) + assert len(c_single_env.buffer) == 8 + assert np.allclose(c_single_env.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0]) + assert np.allclose(c_single_env.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) + assert np.allclose(c_single_env.buffer.info["key"][:8], 1) + for e in c_single_env.buffer.info["env"][:8]: + assert isinstance(e, MoveToRightEnv) + assert np.allclose(c_single_env.buffer.info["env_id"][:8], 0) + assert np.allclose(c_single_env.buffer.rew[:8], [0, 1, 0, 1, 0, 1, 0, 1]) + c_single_env.collect(n_step=3, random=True) + + c_subproc_venv_4_envs = Collector( policy, - venv, + subproc_venv_4_envs, VectorReplayBuffer(total_size=100, buffer_num=4), - logger.preprocess_fn, ) - c1.collect(n_step=8, gym_reset_kwargs=gym_reset_kwargs) + c_subproc_venv_4_envs.reset() + + # Collect some steps + c_subproc_venv_4_envs.collect(n_step=8) obs = np.zeros(100) valid_indices = [0, 1, 25, 26, 50, 51, 75, 76] obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1] - assert np.allclose(c1.buffer.obs[:, 0], obs) - assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) + assert np.allclose(c_subproc_venv_4_envs.buffer.obs[:, 0], obs) + assert np.allclose(c_subproc_venv_4_envs.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) keys = np.zeros(100) keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1] - assert np.allclose(c1.buffer.info["key"], keys) - for e in c1.buffer.info["env"][valid_indices]: - assert isinstance(e, MyTestEnv) + assert np.allclose(c_subproc_venv_4_envs.buffer.info["key"], keys) + for e in 
c_subproc_venv_4_envs.buffer.info["env"][valid_indices]: + assert isinstance(e, MoveToRightEnv) env_ids = np.zeros(100) env_ids[valid_indices] = [0, 0, 1, 1, 2, 2, 3, 3] - assert np.allclose(c1.buffer.info["env_id"], env_ids) + assert np.allclose(c_subproc_venv_4_envs.buffer.info["env_id"], env_ids) rews = np.zeros(100) rews[valid_indices] = [0, 1, 0, 0, 0, 0, 0, 0] - assert np.allclose(c1.buffer.info["rew"], rews) - c1.collect(n_episode=4, gym_reset_kwargs=gym_reset_kwargs) - assert len(c1.buffer) == 16 + assert np.allclose(c_subproc_venv_4_envs.buffer.rew, rews) + + # we previously collected 8 steps, 2 from each env, now we collect 4 episodes + # each env will contribute an episode, which will be of lens 2 (first env was reset), 1, 2, 3 + # So we get 8 + 2+1+2+3 = 16 steps + c_subproc_venv_4_envs.collect(n_episode=4) + assert len(c_subproc_venv_4_envs.buffer) == 16 + valid_indices = [2, 3, 27, 52, 53, 77, 78, 79] - obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4] - assert np.allclose(c1.buffer.obs[:, 0], obs) + obs[valid_indices] = [0, 1, 2, 2, 3, 2, 3, 4] + assert np.allclose(c_subproc_venv_4_envs.buffer.obs[:, 0], obs) assert np.allclose( - c1.buffer[:].obs_next[..., 0], + c_subproc_venv_4_envs.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5], ) keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1] - assert np.allclose(c1.buffer.info["key"], keys) - for e in c1.buffer.info["env"][valid_indices]: - assert isinstance(e, MyTestEnv) + assert np.allclose(c_subproc_venv_4_envs.buffer.info["key"], keys) + for e in c_subproc_venv_4_envs.buffer.info["env"][valid_indices]: + assert isinstance(e, MoveToRightEnv) env_ids[valid_indices] = [0, 0, 1, 2, 2, 3, 3, 3] - assert np.allclose(c1.buffer.info["env_id"], env_ids) + assert np.allclose(c_subproc_venv_4_envs.buffer.info["env_id"], env_ids) rews[valid_indices] = [0, 1, 1, 0, 1, 0, 0, 1] - assert np.allclose(c1.buffer.info["rew"], rews) - c1.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs) + assert np.allclose(c_subproc_venv_4_envs.buffer.rew, rews) + c_subproc_venv_4_envs.collect(n_episode=4, random=True) - c2 = Collector( + c_dummy_venv_4_envs = Collector( policy, - dum, + dummy_venv_4_envs, VectorReplayBuffer(total_size=100, buffer_num=4), - logger.preprocess_fn, ) - c2.collect(n_episode=7, gym_reset_kwargs=gym_reset_kwargs) + c_dummy_venv_4_envs.reset() + c_dummy_venv_4_envs.collect(n_episode=7) obs1 = obs.copy() obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2] obs2 = obs.copy() obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3] - c2obs = c2.buffer.obs[:, 0] + c2obs = c_dummy_venv_4_envs.buffer.obs[:, 0] assert np.all(c2obs == obs1) or np.all(c2obs == obs2) - c2.reset_env(gym_reset_kwargs=gym_reset_kwargs) - c2.reset_buffer() - assert c2.collect(n_episode=8, gym_reset_kwargs=gym_reset_kwargs).n_collected_episodes == 8 + c_dummy_venv_4_envs.reset_env() + c_dummy_venv_4_envs.reset_buffer() + assert c_dummy_venv_4_envs.collect(n_episode=8).n_collected_episodes == 8 valid_indices = [4, 5, 28, 29, 30, 54, 55, 56, 57] obs[valid_indices] = [0, 1, 0, 1, 2, 0, 1, 2, 3] - assert np.all(c2.buffer.obs[:, 0] == obs) + assert np.all(c_dummy_venv_4_envs.buffer.obs[:, 0] == obs) keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1, 1] - assert np.allclose(c2.buffer.info["key"], keys) - for e in c2.buffer.info["env"][valid_indices]: - assert isinstance(e, MyTestEnv) + assert np.allclose(c_dummy_venv_4_envs.buffer.info["key"], keys) + for e in c_dummy_venv_4_envs.buffer.info["env"][valid_indices]: + assert 
isinstance(e, MoveToRightEnv) env_ids[valid_indices] = [0, 0, 1, 1, 1, 2, 2, 2, 2] - assert np.allclose(c2.buffer.info["env_id"], env_ids) + assert np.allclose(c_dummy_venv_4_envs.buffer.info["env_id"], env_ids) rews[valid_indices] = [0, 1, 0, 0, 1, 0, 0, 0, 1] - assert np.allclose(c2.buffer.info["rew"], rews) - c2.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs) + assert np.allclose(c_dummy_venv_4_envs.buffer.rew, rews) + c_dummy_venv_4_envs.collect(n_episode=4, random=True) # test corner case with pytest.raises(TypeError): - Collector(policy, dum, ReplayBuffer(10)) + Collector(policy, dummy_venv_4_envs, ReplayBuffer(10)) with pytest.raises(TypeError): - Collector(policy, dum, PrioritizedReplayBuffer(10, 0.5, 0.5)) + Collector(policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5)) with pytest.raises(TypeError): - c2.collect() + c_dummy_venv_4_envs.collect() # test NXEnv for obs_type in ["array", "object"]: envs = SubprocVectorEnv([lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]]) - c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) - c3.collect(n_step=6, gym_reset_kwargs=gym_reset_kwargs) - assert c3.buffer.obs.dtype == object + c_suproc_new = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) + c_suproc_new.reset() + c_suproc_new.collect(n_step=6) + assert c_suproc_new.buffer.obs.dtype == object -@pytest.mark.parametrize("gym_reset_kwargs", [None, {}]) -def test_collector_with_async(gym_reset_kwargs) -> None: +@pytest.fixture() +def get_AsyncCollector(): env_lens = [2, 3, 4, 5] - writer = SummaryWriter("log/async_collector") - logger = Logger(writer) - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens] + env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens] venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) - policy = MyPolicy() + policy = MaxActionPolicy() bufsize = 60 c1 = AsyncCollector( policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4), - logger.preprocess_fn, ) - ptr = [0, 0, 0, 0] - for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): - result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs) - assert result.n_collected_episodes >= n_episode - # check buffer data, obs and obs_next, env_id - for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]): - env_len = i + 2 - total = env_len * count - indices = np.arange(ptr[i], ptr[i] + total) % bufsize - ptr[i] = (ptr[i] + total) % bufsize - seq = np.arange(env_len) - buf = c1.buffer.buffers[i] - assert np.all(buf.info.env_id[indices] == i) - assert np.all(buf.obs[indices].reshape(count, env_len) == seq) - assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) - # test async n_step, for now the buffer should be full of data - for n_step in tqdm.trange(1, 15, desc="test async n_step"): - result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs) - assert result.n_collected_steps >= n_step - for i in range(4): - env_len = i + 2 - seq = np.arange(env_len) - buf = c1.buffer.buffers[i] - assert np.all(buf.info.env_id == i) - assert np.all(buf.obs.reshape(-1, env_len) == seq) - assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1) - with pytest.raises(TypeError): - c1.collect() + c1.reset() + return c1, env_lens + + +class TestAsyncCollector: + def test_collect_without_argument_gives_error(self, get_AsyncCollector): + c1, env_lens = get_AsyncCollector + with 
pytest.raises(TypeError): + c1.collect() + + def test_collect_one_episode_async(self, get_AsyncCollector): + c1, env_lens = get_AsyncCollector + result = c1.collect(n_episode=1) + assert result.n_collected_episodes >= 1 + + def test_enough_episodes_two_collection_cycles_n_episode_without_reset( + self, + get_AsyncCollector, + ): + c1, env_lens = get_AsyncCollector + n_episode = 2 + result_c1 = c1.collect(n_episode=n_episode, reset_before_collect=False) + assert result_c1.n_collected_episodes >= n_episode + result_c2 = c1.collect(n_episode=n_episode, reset_before_collect=False) + assert result_c2.n_collected_episodes >= n_episode + + def test_enough_episodes_two_collection_cycles_n_episode_with_reset(self, get_AsyncCollector): + c1, env_lens = get_AsyncCollector + n_episode = 2 + result_c1 = c1.collect(n_episode=n_episode, reset_before_collect=True) + assert result_c1.n_collected_episodes >= n_episode + result_c2 = c1.collect(n_episode=n_episode, reset_before_collect=True) + assert result_c2.n_collected_episodes >= n_episode + + def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_n_episode( + self, + get_AsyncCollector, + ): + c1, env_lens = get_AsyncCollector + ptr = [0, 0, 0, 0] + bufsize = 60 + for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): + result = c1.collect(n_episode=n_episode) + assert result.n_collected_episodes >= n_episode + # check buffer data, obs and obs_next, env_id + for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]): + env_len = i + 2 + total = env_len * count + indices = np.arange(ptr[i], ptr[i] + total) % bufsize + ptr[i] = (ptr[i] + total) % bufsize + seq = np.arange(env_len) + buf = c1.buffer.buffers[i] + assert np.all(buf.info.env_id[indices] == i) + assert np.all(buf.obs[indices].reshape(count, env_len) == seq) + assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) + + def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_n_step( + self, + get_AsyncCollector, + ): + c1, env_lens = get_AsyncCollector + bufsize = 60 + ptr = [0, 0, 0, 0] + for n_step in tqdm.trange(1, 15, desc="test async n_step"): + result = c1.collect(n_step=n_step) + assert result.n_collected_steps >= n_step + for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]): + env_len = i + 2 + total = env_len * count + indices = np.arange(ptr[i], ptr[i] + total) % bufsize + ptr[i] = (ptr[i] + total) % bufsize + seq = np.arange(env_len) + buf = c1.buffer.buffers[i] + assert np.all(buf.info.env_id[indices] == i) + assert np.all(buf.obs[indices].reshape(count, env_len) == seq) + assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) + + @pytest.mark.parametrize("gym_reset_kwargs", [None, {}]) + def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_first_n_episode_then_n_step( + self, + get_AsyncCollector, + gym_reset_kwargs, + ): + c1, env_lens = get_AsyncCollector + bufsize = 60 + ptr = [0, 0, 0, 0] + for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): + result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs) + assert result.n_collected_episodes >= n_episode + # check buffer data, obs and obs_next, env_id + for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]): + env_len = i + 2 + total = env_len * count + indices = np.arange(ptr[i], ptr[i] + total) % bufsize + ptr[i] = (ptr[i] + total) % bufsize + seq = np.arange(env_len) + buf = c1.buffer.buffers[i] + assert 
np.all(buf.info.env_id[indices] == i) + assert np.all(buf.obs[indices].reshape(count, env_len) == seq) + assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) + # test async n_step, for now the buffer should be full of data, thus no bincount stuff as above + for n_step in tqdm.trange(1, 15, desc="test async n_step"): + result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs) + assert result.n_collected_steps >= n_step + for i in range(4): + env_len = i + 2 + seq = np.arange(env_len) + buf = c1.buffer.buffers[i] + assert np.all(buf.info.env_id == i) + assert np.all(buf.obs.reshape(-1, env_len) == seq) + assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1) def test_collector_with_dict_state() -> None: - env = MyTestEnv(size=5, sleep=0, dict_state=True) - policy = MyPolicy(dict_state=True) - c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn) + env = MoveToRightEnv(size=5, sleep=0, dict_state=True) + policy = MaxActionPolicy(dict_state=True) + c0 = Collector(policy, env, ReplayBuffer(size=100)) + c0.reset() c0.collect(n_step=3) c0.collect(n_episode=2) - assert len(c0.buffer) == 10 - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]] + assert len(c0.buffer) == 10 # 3 + two episodes with 5 steps each + env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) envs.seed(666) obs, info = envs.reset() @@ -280,8 +356,8 @@ def test_collector_with_dict_state() -> None: policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4), - Logger.single_preprocess_fn, ) + c1.reset() c1.collect(n_step=12) result = c1.collect(n_episode=8) assert result.n_collected_episodes == 8 @@ -396,41 +472,47 @@ def test_collector_with_dict_state() -> None: policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), - Logger.single_preprocess_fn, ) + c2.reset() c2.collect(n_episode=10) batch, _ = c2.buffer.sample(10) -def test_collector_with_ma() -> None: - env = MyTestEnv(size=5, sleep=0, ma_rew=4) - policy = MyPolicy() - c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn) - # n_step=3 will collect a full episode - rew = c0.collect(n_step=3).returns - assert len(rew) == 0 - rew = c0.collect(n_episode=2).returns - assert rew.shape == (2, 4) - assert np.all(rew == 1) - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]] +def test_collector_with_multi_agent() -> None: + multi_agent_env = MoveToRightEnv(size=5, sleep=0, ma_rew=4) + policy = MaxActionPolicy() + c_single_env = Collector(policy, multi_agent_env, ReplayBuffer(size=100)) + c_single_env.reset() + multi_env_returns = c_single_env.collect(n_step=3).returns + # c_single_env has length 3 + # We have no full episodes, so no returns yet + assert len(multi_env_returns) == 0 + + single_env_returns = c_single_env.collect(n_episode=2).returns + # now two episodes. 
Since we have 4 a agents, the returns have shape (2, 4) + assert single_env_returns.shape == (2, 4) + assert np.all(single_env_returns == 1) + + env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) - c1 = Collector( + c_multi_env_ma = Collector( policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4), - Logger.single_preprocess_fn, ) - rew = c1.collect(n_step=12).returns - assert rew.shape == (2, 4) and np.all(rew == 1), rew - rew = c1.collect(n_episode=8).returns - assert rew.shape == (8, 4) - assert np.all(rew == 1) - batch, _ = c1.buffer.sample(10) + c_multi_env_ma.reset() + multi_env_returns = c_multi_env_ma.collect(n_step=12).returns + # each env makes 3 steps, the first two envs are done and result in two finished episodes + assert multi_env_returns.shape == (2, 4) and np.all(multi_env_returns == 1), multi_env_returns + multi_env_returns = c_multi_env_ma.collect(n_episode=8).returns + assert multi_env_returns.shape == (8, 4) + assert np.all(multi_env_returns == 1) + batch, _ = c_multi_env_ma.buffer.sample(10) print(batch) - c0.buffer.update(c1.buffer) - assert len(c0.buffer) in [42, 43] - if len(c0.buffer) == 42: - rew = [ + c_single_env.buffer.update(c_multi_env_ma.buffer) + assert len(c_single_env.buffer) in [42, 43] + if len(c_single_env.buffer) == 42: + multi_env_returns = [ 0, 0, 0, @@ -475,7 +557,7 @@ def test_collector_with_ma() -> None: 1, ] else: - rew = [ + multi_env_returns = [ 0, 0, 0, @@ -520,17 +602,17 @@ def test_collector_with_ma() -> None: 0, 1, ] - assert np.all(c0.buffer[:].rew == [[x] * 4 for x in rew]) - assert np.all(c0.buffer[:].done == rew) + assert np.all(c_single_env.buffer[:].rew == [[x] * 4 for x in multi_env_returns]) + assert np.all(c_single_env.buffer[:].done == multi_env_returns) c2 = Collector( policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), - Logger.single_preprocess_fn, ) - rew = c2.collect(n_episode=10).returns - assert rew.shape == (10, 4) - assert np.all(rew == 1) + c2.reset() + multi_env_returns = c2.collect(n_episode=10).returns + assert multi_env_returns.shape == (10, 4) + assert np.all(multi_env_returns == 1) batch, _ = c2.buffer.sample(10) @@ -543,20 +625,21 @@ def test_collector_with_atari_setting() -> None: reference_obs[i, 0] = i # atari single buffer - env = MyTestEnv(size=5, sleep=0, array_state=True) - policy = MyPolicy() + env = MoveToRightEnv(size=5, sleep=0, array_state=True) + policy = MaxActionPolicy() c0 = Collector(policy, env, ReplayBuffer(size=100)) + c0.reset() c0.collect(n_step=6) c0.collect(n_episode=2) assert c0.buffer.obs.shape == (100, 4, 84, 84) assert c0.buffer.obs_next.shape == (100, 4, 84, 84) - assert len(c0.buffer) == 15 + assert len(c0.buffer) == 15 # 6 + 2 episodes with 5 steps each obs = np.zeros_like(c0.buffer.obs) obs[np.arange(15)] = reference_obs[np.arange(15) % 5] assert np.all(obs == c0.buffer.obs) c1 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=True)) - c1.collect(n_episode=3) + c1.collect(n_episode=3, reset_before_collect=True) assert np.allclose(c0.buffer.obs, c1.buffer.obs) with pytest.raises(AttributeError): c1.buffer.obs_next # noqa: B018 @@ -567,6 +650,7 @@ def test_collector_with_atari_setting() -> None: env, ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True), ) + c2.reset() c2.collect(n_step=8) assert c2.buffer.obs.shape == (100, 84, 84) obs = np.zeros_like(c2.buffer.obs) @@ -575,9 +659,10 @@ def test_collector_with_atari_setting() -> None: assert 
np.allclose(c2.buffer[:].obs_next, reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1]) # atari multi buffer - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, array_state=True) for i in [2, 3, 4, 5]] + env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, array_state=True) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) + c3.reset() c3.collect(n_step=12) result = c3.collect(n_episode=9) assert result.n_collected_episodes == 9 @@ -606,6 +691,7 @@ def test_collector_with_atari_setting() -> None: save_only_last_obs=True, ), ) + c4.reset() c4.collect(n_step=12) result = c4.collect(n_episode=9) assert result.n_collected_episodes == 9 @@ -672,6 +758,7 @@ def test_collector_with_atari_setting() -> None: buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True, save_only_last_obs=True) c5 = Collector(policy, envs, CachedReplayBuffer(buf, 4, 10)) + c5.reset() result_ = c5.collect(n_step=12) assert len(buf) == 5 assert len(c5.buffer) == 12 @@ -767,6 +854,7 @@ def test_collector_with_atari_setting() -> None: # test buffer=None c6 = Collector(policy, envs) + c6.reset() result1 = c6.collect(n_step=12) for key in ["n_collected_episodes", "n_collected_steps", "returns", "lens"]: assert np.allclose(getattr(result1, key), getattr(result_, key)) @@ -778,7 +866,7 @@ def test_collector_with_atari_setting() -> None: @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_collector_envpool_gym_reset_return_info() -> None: envs = envpool.make_gymnasium("Pendulum-v1", num_envs=4, gym_reset_return_info=True) - policy = MyPolicy(action_shape=(len(envs), 1)) + policy = MaxActionPolicy(action_shape=(len(envs), 1)) c0 = Collector( policy, @@ -786,18 +874,59 @@ def test_collector_envpool_gym_reset_return_info() -> None: VectorReplayBuffer(len(envs) * 10, len(envs)), exploration_noise=True, ) + c0.reset() c0.collect(n_step=8) env_ids = np.zeros(len(envs) * 10) env_ids[[0, 1, 10, 11, 20, 21, 30, 31]] = [0, 0, 1, 1, 2, 2, 3, 3] assert np.allclose(c0.buffer.info["env_id"], env_ids) +def test_collector_with_vector_env(): + env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [1, 8, 9, 10]] + + dum = DummyVectorEnv(env_fns) + policy = MaxActionPolicy() + + c2 = Collector( + policy, + dum, + VectorReplayBuffer(total_size=100, buffer_num=4), + ) + + c2.reset() + + c1r = c2.collect(n_episode=2) + assert np.array_equal(np.array([1, 8]), c1r.lens) + c2r = c2.collect(n_episode=10) + assert np.array_equal(np.array([1, 1, 1, 1, 1, 1, 1, 8, 9, 10]), c2r.lens) + c3r = c2.collect(n_step=20) + assert np.array_equal(np.array([1, 1, 1, 1, 1]), c3r.lens) + c4r = c2.collect(n_step=20) + assert np.array_equal(np.array([1, 1, 1, 8, 1, 9, 1, 10]), c4r.lens) + + +def test_async_collector_with_vector_env(): + env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [1, 8, 9, 10]] + + dum = DummyVectorEnv(env_fns) + policy = MaxActionPolicy() + c1 = AsyncCollector( + policy, + dum, + VectorReplayBuffer(total_size=100, buffer_num=4), + ) + + c1r = c1.collect(n_episode=10, reset_before_collect=True) + assert np.array_equal(np.array([1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 9]), c1r.lens) + c2r = c1.collect(n_step=20) + assert np.array_equal(np.array([1, 10, 1, 1, 1, 1]), c2r.lens) + + if __name__ == "__main__": - test_collector(gym_reset_kwargs=None) - test_collector(gym_reset_kwargs={}) + test_collector() test_collector_with_dict_state() - test_collector_with_ma() + test_collector_with_multi_agent() 
test_collector_with_atari_setting() - test_collector_with_async(gym_reset_kwargs=None) - test_collector_with_async(gym_reset_kwargs={"return_info": True}) test_collector_envpool_gym_reset_return_info() + test_collector_with_vector_env() + test_async_collector_with_vector_env() diff --git a/test/base/test_env.py b/test/base/test_env.py index edeb3f361..f1571ca8a 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -20,9 +20,9 @@ from tianshou.utils import RunningMeanStd if __name__ == "__main__": - from env import MyTestEnv, NXEnv + from env import MoveToRightEnv, NXEnv else: # pytest - from test.base.env import MyTestEnv, NXEnv + from test.base.env import MoveToRightEnv, NXEnv try: import envpool @@ -56,7 +56,7 @@ def recurse_comp(a, b): def test_async_env(size=10000, num=8, sleep=0.1) -> None: # simplify the test case, just keep stepping env_fns = [ - lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True) + lambda i=i: MoveToRightEnv(size=i, sleep=sleep, random_sleep=True) for i in range(size, size + num) ] test_cls = [SubprocVectorEnv, ShmemVectorEnv] @@ -108,10 +108,10 @@ def test_async_env(size=10000, num=8, sleep=0.1) -> None: def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7) -> None: env_fns = [ - lambda: MyTestEnv(size=size, sleep=sleep * 2), - lambda: MyTestEnv(size=size, sleep=sleep * 3), - lambda: MyTestEnv(size=size, sleep=sleep * 5), - lambda: MyTestEnv(size=size, sleep=sleep * 7), + lambda: MoveToRightEnv(size=size, sleep=sleep * 2), + lambda: MoveToRightEnv(size=size, sleep=sleep * 3), + lambda: MoveToRightEnv(size=size, sleep=sleep * 5), + lambda: MoveToRightEnv(size=size, sleep=sleep * 7), ] test_cls = [SubprocVectorEnv, ShmemVectorEnv] if has_ray(): @@ -156,7 +156,7 @@ def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7) -> None: def test_vecenv(size=10, num=8, sleep=0.001) -> None: env_fns = [ - lambda i=i: MyTestEnv(size=i, sleep=sleep, recurse_state=True) + lambda i=i: MoveToRightEnv(size=i, sleep=sleep, recurse_state=True) for i in range(size, size + num) ] venv = [ @@ -237,7 +237,7 @@ def test_env_obs_dtype() -> None: def test_env_reset_optional_kwargs(size=10000, num=8) -> None: - env_fns = [lambda i=i: MyTestEnv(size=i) for i in range(size, size + num)] + env_fns = [lambda i=i: MoveToRightEnv(size=i) for i in range(size, size + num)] test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv] if has_ray(): test_cls += [RayVectorEnv] @@ -257,7 +257,7 @@ def test_venv_wrapper_gym(num_envs: int = 4) -> None: except ValueError: obs, info = envs.reset(return_info=True) assert isinstance(obs, np.ndarray) - assert isinstance(info, list) + assert isinstance(info, np.ndarray) assert isinstance(info[0], dict) assert obs.shape[0] == len(info) == num_envs @@ -334,7 +334,7 @@ def test_venv_norm_obs() -> None: action = np.array([1, 1, 1, 1]) total_step = 30 action_list = [action] * total_step - env_fns = [lambda i=x: MyTestEnv(size=i, array_state=True) for x in sizes] + env_fns = [lambda i=x: MoveToRightEnv(size=i, array_state=True) for x in sizes] raw = DummyVectorEnv(env_fns) train_env = VectorEnvNormObs(DummyVectorEnv(env_fns)) print(train_env.observation_space) diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index d1e780251..651e77082 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -90,20 +90,20 @@ def _get_default_info(self): # END - def reset(self, id=None): - id = self._wrap_id(id) + def reset(self, env_id=None): + env_id = self._wrap_id(env_id) 
self._reset_alive_envs() # ask super to reset alive envs and remap to current index - request_id = list(filter(lambda i: i in self._alive_env_ids, id)) - obs = [None] * len(id) - infos = [None] * len(id) - id2idx = {i: k for k, i in enumerate(id)} + request_id = list(filter(lambda i: i in self._alive_env_ids, env_id)) + obs = [None] * len(env_id) + infos = [None] * len(env_id) + id2idx = {i: k for k, i in enumerate(env_id)} if request_id: for k, o, info in zip(request_id, *super().reset(request_id), strict=True): obs[id2idx[k]] = o infos[id2idx[k]] = info - for i, o in zip(id, obs, strict=True): + for i, o in zip(env_id, obs, strict=True): if o is None and i in self._alive_env_ids: self._alive_env_ids.remove(i) @@ -121,7 +121,7 @@ def reset(self, id=None): self.reset() raise StopIteration - return np.stack(obs), infos + return np.stack(obs), np.array(infos) def step(self, action, id=None): id = self._wrap_id(id) @@ -204,10 +204,12 @@ def test_finite_dummy_vector_env() -> None: envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)]) policy = AnyPolicy() test_collector = Collector(policy, envs, exploration_noise=True) + test_collector.reset() for _ in range(3): envs.tracker = MetricTracker() try: + # TODO: why on earth 10**18? test_collector.collect(n_step=10**18) except StopIteration: envs.tracker.validate() @@ -218,6 +220,7 @@ def test_finite_subproc_vector_env() -> None: envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)]) policy = AnyPolicy() test_collector = Collector(policy, envs, exploration_noise=True) + test_collector.reset() for _ in range(3): envs.tracker = MetricTracker() diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 20177f429..a180e44bb 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -136,6 +136,7 @@ def linear(x: int, y: int) -> nn.Module: exploration_noise=True, ) test_collector = Collector(policy, test_envs) + train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) # log log_path = os.path.join(args.logdir, args.task, "redq") diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 7b8690f94..fb1e28a83 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -162,6 +162,7 @@ def stop_fn(mean_rewards: float) -> bool: env = gym.make(args.task) policy.eval() collector = Collector(policy, env) + collector.reset() collector_stats = collector.collect(n_episode=1, render=args.render) print(collector_stats) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 3ca7ce6cf..f51a8d75a 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -109,7 +109,9 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) + train_collector.reset() test_collector = Collector(policy, test_envs) + test_collector.reset() # log log_path = os.path.join(args.logdir, args.task, "a2c") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index a45b02114..295c6b378 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -108,7 +108,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: ) test_collector = Collector(policy, test_envs, exploration_noise=False) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + 
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) def train_fn(epoch: int, env_step: int) -> None: # exp decay eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test) diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 013e2c414..c9731ed3b 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -120,7 +120,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, args.task, "c51") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index d0aba10c4..de598e1d2 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -111,7 +111,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, args.task, "dqn") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 3ed1f4fbe..03ece9bde 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -95,7 +95,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None: # the stack_num is for RNN training: sample framestack obs test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, args.task, "drqn") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 7de090119..fa7a4ca4a 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -128,7 +128,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, args.task, "fqf") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 5ff71f515..1f75ab516 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -124,7 +124,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, args.task, "iqn") 
writer = SummaryWriter(log_path) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index c1bbcc3fa..b3e42d5e5 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -113,7 +113,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, args.task, "qrdqn") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index fafa1e03b..a38433ba4 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -128,7 +128,7 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, args.task, "rainbow") writer = SummaryWriter(log_path) diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 0f957c75d..2aa5b4e9f 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -154,7 +154,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, args.task, "dqn_icm") writer = SummaryWriter(log_path) diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 55719f47e..2994b11dd 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -81,7 +81,9 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None: VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) + train_collector.reset() test_collector = Collector(policy, test_envs) + test_collector.reset() # Logger log_path = os.path.join(args.logdir, args.task, "psrl") writer = SummaryWriter(log_path) @@ -120,7 +122,6 @@ def stop_fn(mean_rewards: float) -> bool: # Let's watch its performance! 
policy.eval() test_envs.seed(args.seed) - test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) print(f"Final reward: {result.rew_mean}, length: {result.len_mean}") elif env.spec.reward_threshold: diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 7f1a3128b..61450a932 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -115,9 +115,11 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector train_collector = Collector(policy, train_envs, buf, exploration_noise=True) + train_collector.reset() test_collector = Collector(policy, test_envs, exploration_noise=True) + test_collector.reset() # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, args.task, "qrdqn") writer = SummaryWriter(log_path) @@ -165,6 +167,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) policy.set_eps(0.2) collector = Collector(policy, test_envs, buf, exploration_noise=True) + collector.reset() collector_stats = collector.collect(n_step=args.buffer_size) if args.save_buffer_name.endswith(".hdf5"): buf.save_hdf5(args.save_buffer_name) diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 32e6d5696..81fc899d4 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -178,6 +178,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: def test_discrete_bcq_resume(args: argparse.Namespace = get_args()) -> None: + test_discrete_bcq() args.resume = True test_discrete_bcq(args) diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 0dd750b4c..990bf4694 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -83,8 +83,8 @@ def get_agents( if isinstance(env.observation_space, gym.spaces.Dict) else env.observation_space ) - args.state_shape = observation_space.shape or observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.state_shape = observation_space.shape or int(observation_space.n) + args.action_shape = env.action_space.shape or int(env.action_space.n) if agents is None: agents = [] optims = [] @@ -135,7 +135,7 @@ def train_agent( exploration_noise=True, ) test_collector = Collector(policy, test_envs, exploration_noise=True) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, "pistonball", "dqn") writer = SummaryWriter(log_path) diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 8bbb20cfd..0897d73ad 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -234,7 +234,7 @@ def train_agent( exploration_noise=False, # True ) test_collector = Collector(policy, test_envs) - # train_collector.collect(n_step=args.batch_size * args.training_num) + # train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, "pistonball", "dqn") writer = 
SummaryWriter(log_path) diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 62b66dfa4..e1559b113 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -102,8 +102,8 @@ def get_agents( if isinstance(env.observation_space, gymnasium.spaces.Dict) else env.observation_space ) - args.state_shape = observation_space.shape or observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.state_shape = observation_space.shape or int(observation_space.n) + args.action_shape = env.action_space.shape or int(env.action_space.n) if agent_learn is None: # model net = Net( @@ -170,7 +170,7 @@ def train_agent( ) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, "tic_tac_toe", "dqn") writer = SummaryWriter(log_path) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 7002b55d6..b9b702409 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -263,6 +263,9 @@ def __truediv__(self, value: Number | np.number) -> Self: def __repr__(self) -> str: ... + def __iter__(self) -> Iterator[Self]: + ... + def to_numpy(self) -> None: """Change all torch.Tensor to numpy.ndarray in-place.""" ... @@ -391,6 +394,12 @@ def split( """ ... + def to_dict(self) -> dict[str, Any]: + ... + + def to_list_of_dicts(self) -> list[dict[str, Any]]: + ... + class Batch(BatchProtocol): """See :class:`~tianshou.data.batch.BatchProtocol`.""" @@ -422,6 +431,17 @@ def __init__( # Feels like kwargs could be just merged into batch_dict in the beginning self.__init__(kwargs, copy=copy) # type: ignore + def to_dict(self) -> dict[str, Any]: + result = {} + for k, v in self.__dict__.items(): + if isinstance(v, Batch): + v = v.to_dict() + result[k] = v + return result + + def to_list_of_dicts(self) -> list[dict[str, Any]]: + return [entry.to_dict() for entry in self] + def __setattr__(self, key: str, value: Any) -> None: """Set self.key = value.""" self.__dict__[key] = _parse_value(value) @@ -478,6 +498,14 @@ def __getitem__(self, index: str | IndexType) -> Any: return new_batch raise IndexError("Cannot access item from empty Batch object.") + def __iter__(self) -> Iterator[Self]: + # TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea + if len(self.__dict__) == 0: + yield from [] + else: + for i in range(len(self)): + yield self[i] + def __setitem__(self, index: str | IndexType, value: Any) -> None: """Assign value to self[index].""" value = _parse_value(value) @@ -601,10 +629,10 @@ def to_torch( else: # ndarray or scalar if not isinstance(obj, np.ndarray): - obj = np.asanyarray(obj) # noqa: PLW2901 - obj = torch.from_numpy(obj).to(device) # noqa: PLW2901 + obj = np.asanyarray(obj) + obj = torch.from_numpy(obj).to(device) if dtype is not None: - obj = obj.type(dtype) # noqa: PLW2901 + obj = obj.type(dtype) self.__dict__[batch_key] = obj def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None: diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py index e09b69667..a495b0ada 100644 --- a/tianshou/data/buffer/manager.py +++ b/tianshou/data/buffer/manager.py @@ -200,7 +200,7 @@ def sample_indices(self, batch_size: int | None) -> np.ndarray: return np.concatenate( [ - buf.sample_indices(bsz) + offset 
+ buf.sample_indices(int(bsz)) + offset for offset, buf, bsz in zip(self._offset, self.buffers, sample_num, strict=True) ], ) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 09290c756..751fedfb2 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -1,8 +1,8 @@ import time import warnings -from collections.abc import Callable +from copy import copy from dataclasses import dataclass -from typing import Any, cast +from typing import Any, Self, TypeVar, cast import gymnasium as gym import numpy as np @@ -18,8 +18,10 @@ VectorReplayBuffer, to_numpy, ) -from tianshou.data.batch import alloc_by_keys_diff -from tianshou.data.types import RolloutBatchProtocol +from tianshou.data.types import ( + ObsBatchProtocol, + RolloutBatchProtocol, +) from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.policy import BasePolicy from tianshou.utils.print import DataclassPPrintMixin @@ -45,13 +47,80 @@ class CollectStats(CollectStatsBase): """The speed of collecting (env_step per second).""" returns: np.ndarray """The collected episode returns.""" - returns_stat: SequenceSummaryStats | None # can be None if no episode ends during collect step + returns_stat: SequenceSummaryStats | None # can be None if no episode ends during the collect step """Stats of the collected returns.""" lens: np.ndarray """The collected episode lengths.""" - lens_stat: SequenceSummaryStats | None # can be None if no episode ends during collect step + lens_stat: SequenceSummaryStats | None # can be None if no episode ends during the collect step """Stats of the collected episode lengths.""" + @classmethod + def with_autogenerated_stats( + cls, + returns: np.ndarray, + lens: np.ndarray, + n_collected_episodes: int = 0, + n_collected_steps: int = 0, + collect_time: float = 0.0, + collect_speed: float = 0.0, + ) -> Self: + """Return a new instance with the stats autogenerated from the given lists.""" + returns_stat = SequenceSummaryStats.from_sequence(returns) if returns.size > 0 else None + lens_stat = SequenceSummaryStats.from_sequence(lens) if lens.size > 0 else None + return cls( + n_collected_episodes=n_collected_episodes, + n_collected_steps=n_collected_steps, + collect_time=collect_time, + collect_speed=collect_speed, + returns=returns, + returns_stat=returns_stat, + lens=np.array(lens, int), + lens_stat=lens_stat, + ) + + +_TArrLike = TypeVar("_TArrLike", bound="np.ndarray | torch.Tensor | Batch | None") + + +def _nullable_slice(obj: _TArrLike, indices: np.ndarray) -> _TArrLike: + """Return None, or the values at the given indices if the object is not None.""" + if obj is not None: + return obj[indices] # type: ignore[index, return-value] + return None # type: ignore[unreachable] + + +def _dict_of_arr_to_arr_of_dicts(dict_of_arr: dict[str, np.ndarray | dict]) -> np.ndarray: + return np.array(Batch(dict_of_arr).to_list_of_dicts()) + + +def _HACKY_create_info_batch(info_array: np.ndarray) -> Batch: + """TODO: this exists because of multiple bugs in Batch and to restore backwards compatibility. + Batch should be fixed and this function should be removed asap!. 
+ """ + if info_array.dtype != np.dtype("O"): + raise ValueError( + f"Expected info_array to have dtype=object, but got {info_array.dtype}.", + ) + + truthy_info_indices = info_array.nonzero()[0] + falsy_info_indices = set(range(len(info_array))) - set(truthy_info_indices) + falsy_info_indices = np.array(list(falsy_info_indices), dtype=int) + + if len(falsy_info_indices) == len(info_array): + return Batch() + + some_nonempty_info = None + for info in info_array: + if info: + some_nonempty_info = info + break + + info_array = copy(info_array) + info_array[falsy_info_indices] = some_nonempty_info + result_batch_parent = Batch(info=info_array) + result_batch_parent.info[falsy_info_indices] = {} + return result_batch_parent.info + class Collector: """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. @@ -60,23 +129,13 @@ class Collector: :param env: a ``gym.Env`` environment or an instance of the :class:`~tianshou.env.BaseVectorEnv` class. :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. - If set to None, it will not store the data. Default to None. - :param function preprocess_fn: a function called before the data has been added to - the buffer, see issue #42 and :ref:`preprocess_fn`. Default to None. + If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` + as the default buffer. :param exploration_noise: determine whether the action needs to be modified - with corresponding policy's exploration noise. If so, "policy. + with the corresponding policy's exploration noise. If so, "policy. exploration_noise(act, batch)" will be called automatically to add the exploration noise into action. Default to False. - The "preprocess_fn" is a function called before the data has been added to the - buffer with batch format. It will receive only "obs" and "env_id" when the - collector resets the environment, and will receive the keys "obs_next", "rew", - "terminated", "truncated, "info", "policy" and "env_id" in a normal env step. - Alternatively, it may also accept the keys "obs_next", "rew", "done", "info", - "policy" and "env_id". - It returns either a dict or a :class:`~tianshou.data.Batch` with the modified - keys and values. Examples are in "test/base/test_collector.py". - .. note:: Please make sure the given environment has a time limitation if using n_episode @@ -84,7 +143,7 @@ class Collector: .. note:: - In past versions of Tianshou, the replay buffer that was passed to `__init__` + In past versions of Tianshou, the replay buffer passed to `__init__` was automatically reset. This is not done in the current implementation. 
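+
+    A minimal usage sketch (``policy`` is a placeholder for any constructed policy;
+    the environment and the numbers are merely illustrative)::
+
+        import gymnasium as gym
+
+        from tianshou.data import Collector, VectorReplayBuffer
+        from tianshou.env import DummyVectorEnv
+
+        envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
+        collector = Collector(policy, envs, VectorReplayBuffer(2000, len(envs)))
+        collector.reset()  # or pass reset_before_collect=True to collect()
+        collector.collect(n_step=200)
+        collector.close()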
""" @@ -93,7 +152,6 @@ def __init__( policy: BasePolicy, env: gym.Env | BaseVectorEnv, buffer: ReplayBuffer | None = None, - preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None, exploration_noise: bool = False, ) -> None: super().__init__() @@ -105,16 +163,30 @@ def __init__( self.env = env # type: ignore self.env_num = len(self.env) self.exploration_noise = exploration_noise - self.buffer: ReplayBuffer - self._assign_buffer(buffer) + self.buffer = self._assign_buffer(buffer) self.policy = policy - self.preprocess_fn = preprocess_fn self._action_space = self.env.action_space - self.data: RolloutBatchProtocol - # avoid creating attribute outside __init__ - self.reset(False) - def _assign_buffer(self, buffer: ReplayBuffer | None) -> None: + self._pre_collect_obs_RO: np.ndarray | None = None + self._pre_collect_info_R: np.ndarray | None = None + self._pre_collect_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None + + self._is_closed = False + self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 + + def close(self) -> None: + """Close the collector and the environment.""" + self.env.close() + self._pre_collect_obs_RO = None + self._pre_collect_info_R = None + self._is_closed = True + + @property + def is_closed(self) -> bool: + """Return True if the collector is closed.""" + return self._is_closed + + def _assign_buffer(self, buffer: ReplayBuffer | None) -> ReplayBuffer: """Check if the buffer matches the constraint.""" if buffer is None: buffer = VectorReplayBuffer(self.env_num, self.env_num) @@ -136,38 +208,28 @@ def _assign_buffer(self, buffer: ReplayBuffer | None) -> None: f"{self.env_num} envs,\n\tplease use {vector_type}(total_size=" f"{buffer.maxsize}, buffer_num={self.env_num}, ...) instead.", ) - self.buffer = buffer + return buffer def reset( self, reset_buffer: bool = True, + reset_stats: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, ) -> None: - """Reset the environment, statistics, current data and possibly replay memory. + """Reset the environment, statistics, and data needed to start the collection. - :param reset_buffer: if true, reset the replay buffer that is attached + :param reset_buffer: if true, reset the replay buffer attached to the collector. + :param reset_stats: if true, reset the statistics attached to the collector. :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. 
Defaults to None (extra keyword arguments) """ - # use empty Batch for "state" so that self.data supports slicing - # convert empty Batch to None when passing data to policy - data = Batch( - obs={}, - act={}, - rew={}, - terminated={}, - truncated={}, - done={}, - obs_next={}, - info={}, - policy={}, - ) - self.data = cast(RolloutBatchProtocol, data) - self.reset_env(gym_reset_kwargs) + self.reset_env(gym_reset_kwargs=gym_reset_kwargs) if reset_buffer: self.reset_buffer() - self.reset_stat() + if reset_stats: + self.reset_stat() + self._is_closed = False def reset_stat(self) -> None: """Reset the statistic variables.""" @@ -177,44 +239,76 @@ def reset_buffer(self, keep_statistics: bool = False) -> None: """Reset the data buffer.""" self.buffer.reset(keep_statistics=keep_statistics) - def reset_env(self, gym_reset_kwargs: dict[str, Any] | None = None) -> None: - """Reset all of the environments.""" - gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} - obs, info = self.env.reset(**gym_reset_kwargs) - if self.preprocess_fn: - processed_data = self.preprocess_fn(obs=obs, info=info, env_id=np.arange(self.env_num)) - obs = processed_data.get("obs", obs) - info = processed_data.get("info", info) - self.data.info = info # type: ignore - self.data.obs = obs - - def _reset_state(self, id: int | list[int]) -> None: - """Reset the hidden state: self.data.state[id].""" - if hasattr(self.data.policy, "hidden_state"): - state = self.data.policy.hidden_state # it is a reference - if isinstance(state, torch.Tensor): - state[id].zero_() - elif isinstance(state, np.ndarray): - state[id] = None if state.dtype == object else 0 - elif isinstance(state, Batch): - state.empty_(id) - - def _reset_env_with_ids( + def reset_env( self, - local_ids: list[int] | np.ndarray, - global_ids: list[int] | np.ndarray, gym_reset_kwargs: dict[str, Any] | None = None, ) -> None: - gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} - obs_reset, info = self.env.reset(global_ids, **gym_reset_kwargs) - if self.preprocess_fn: - processed_data = self.preprocess_fn(obs=obs_reset, info=info, env_id=global_ids) - obs_reset = processed_data.get("obs", obs_reset) - info = processed_data.get("info", info) - self.data.info[local_ids] = info # type: ignore + """Reset the environments and the initial obs, info, and hidden state of the collector.""" + gym_reset_kwargs = gym_reset_kwargs or {} + self._pre_collect_obs_RO, self._pre_collect_info_R = self.env.reset(**gym_reset_kwargs) + # TODO: hack, wrap envpool envs such that they don't return a dict + if isinstance(self._pre_collect_info_R, dict): # type: ignore[unreachable] + # this can happen if the env is an envpool env. 
Then the thing returned by reset is a dict + # with array entries instead of an array of dicts + # We use Batch to turn it into an array of dicts + self._pre_collect_info_R = _dict_of_arr_to_arr_of_dicts(self._pre_collect_info_R) # type: ignore[unreachable] + + self._pre_collect_hidden_state_RH = None + + def _compute_action_policy_hidden( + self, + random: bool, + ready_env_ids_R: np.ndarray, + use_grad: bool, + last_obs_RO: np.ndarray, + last_info_R: np.ndarray, + last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None, + ) -> tuple[np.ndarray, np.ndarray, Batch, np.ndarray | torch.Tensor | Batch | None]: + """Returns the action, the normalized action, a "policy" entry, and the hidden state.""" + if random: + try: + act_normalized_RA = np.array( + [self._action_space[i].sample() for i in ready_env_ids_R], + ) + # TODO: test whether envpool env explicitly + except TypeError: # envpool's action space is not for per-env + act_normalized_RA = np.array([self._action_space.sample() for _ in ready_env_ids_R]) + act_RA = self.policy.map_action_inverse(np.array(act_normalized_RA)) + policy_R = Batch() + hidden_state_RH = None + + else: + info_batch = _HACKY_create_info_batch(last_info_R) + obs_batch_R = cast(ObsBatchProtocol, Batch(obs=last_obs_RO, info=info_batch)) - self.data.obs_next[local_ids] = obs_reset # type: ignore + with torch.set_grad_enabled(use_grad): + act_batch_RA = self.policy( + obs_batch_R, + last_hidden_state_RH, + ) + + act_RA = to_numpy(act_batch_RA.act) + if self.exploration_noise: + act_RA = self.policy.exploration_noise(act_RA, obs_batch_R) + act_normalized_RA = self.policy.map_action(act_RA) + + # TODO: cleanup the whole policy in batch thing + # todo policy_R can also be none, check + policy_R = act_batch_RA.get("policy", Batch()) + if not isinstance(policy_R, Batch): + raise RuntimeError( + f"The policy result should be a {Batch}, but got {type(policy_R)}", + ) + hidden_state_RH = act_batch_RA.get("state", None) + # TODO: do we need the conditional? Would be better to just add hidden_state which could be None + if hidden_state_RH is not None: + policy_R.hidden_state = ( + hidden_state_RH # save state into buffer through policy attr + ) + return act_RA, act_normalized_RA, policy_R, hidden_state_RH + + # TODO: reduce complexity, remove the noqa def collect( self, n_step: int | None = None, @@ -222,49 +316,74 @@ def collect( random: bool = False, render: float | None = None, no_grad: bool = True, + reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, ) -> CollectStats: - """Collect a specified number of step or episode. + """Collect a specified number of steps or episodes. - To ensure unbiased sampling result with n_episode option, this function will + To ensure an unbiased sampling result with the n_episode option, this function will first collect ``n_episode - env_num`` episodes, then for the last ``env_num`` episodes, they will be collected evenly from each env. :param n_step: how many steps you want to collect. :param n_episode: how many episodes you want to collect. - :param random: whether to use random policy for collecting data. Default - to False. + :param random: whether to use random policy for collecting data. :param render: the sleep time between rendering consecutive frames. - Default to None (no rendering). - :param no_grad: whether to retain gradient in policy.forward(). Default to - True (no gradient retaining). + :param no_grad: whether to retain gradient in policy.forward(). 
+ :param reset_before_collect: whether to reset the environment before + collecting data. + It has only an effect if n_episode is not None, i.e. + if one wants to collect a fixed number of episodes. + (The collector needs the initial obs and info to function properly.) :param gym_reset_kwargs: extra keyword arguments to pass into the environment's - reset function. Defaults to None (extra keyword arguments) + reset function. Only used if reset_before_collect is True. .. note:: One and only one collection number specification is permitted, either ``n_step`` or ``n_episode``. - :return: A dataclass object + :return: The collected stats """ + # NAMING CONVENTION (mostly suffixes): + # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed, + # the corresponding env is either reset or removed from the ready envs. + # R - number ready env ids. Note that this might change when envs get idle. + # This can only happen in n_episode case, see explanation in the corresponding block. + # For n_step, we always use all envs to collect the data, while for n_episode, + # R will be at most n_episode at the beginning, but can decrease during the collection. + # O - dimension(s) of observations + # A - dimension(s) of actions + # H - dimension(s) of hidden state + # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case. + # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration. + # Only used in n_episode case. Then, R becomes R-S. + + use_grad = not no_grad + gym_reset_kwargs = gym_reset_kwargs or {} + + # Input validation assert not self.env.is_async, "Please use AsyncCollector if using async venv." if n_step is not None: assert n_episode is None, ( f"Only one of n_step or n_episode is allowed in Collector." - f"collect, got n_step={n_step}, n_episode={n_episode}." + f"collect, got {n_step=}, {n_episode=}." ) assert n_step > 0 if n_step % self.env_num != 0: warnings.warn( - f"n_step={n_step} is not a multiple of #env ({self.env_num}), " - "which may cause extra transitions collected into the buffer.", + f"{n_step=} is not a multiple of ({self.env_num=}), " + "which may cause extra transitions being collected into the buffer.", ) - ready_env_ids = np.arange(self.env_num) + ready_env_ids_R = np.arange(self.env_num) elif n_episode is not None: assert n_episode > 0 - ready_env_ids = np.arange(min(self.env_num, n_episode)) - self.data = self.data[: min(self.env_num, n_episode)] + if self.env_num > n_episode: + warnings.warn( + f"{n_episode=} should be larger than {self.env_num=} to " + f"collect at least one trajectory in each environment.", + ) + ready_env_ids_R = np.arange(min(self.env_num, n_episode)) else: raise TypeError( "Please specify at least one (either n_step or n_episode) " @@ -273,149 +392,209 @@ def collect( start_time = time.time() + if reset_before_collect: + self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) + + if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None: + raise ValueError( + "Initial obs and info should not be None. 
" + "Either reset the collector (using reset or reset_env) or pass reset_before_collect=True to collect.", + ) + + # get the first obs to be the current obs in the n_step case as + # episodes as a new call to collect does not restart trajectories + # (which we also really don't want) step_count = 0 - episode_count = 0 + num_collected_episodes = 0 episode_returns: list[float] = [] episode_lens: list[int] = [] episode_start_indices: list[int] = [] + # in case we select fewer episodes than envs, we run only some of them + last_obs_RO = _nullable_slice(self._pre_collect_obs_RO, ready_env_ids_R) + last_info_R = _nullable_slice(self._pre_collect_info_R, ready_env_ids_R) + last_hidden_state_RH = _nullable_slice( + self._pre_collect_hidden_state_RH, + ready_env_ids_R, + ) + while True: - assert len(self.data) == len(ready_env_ids) + # todo check if we need this when using cur_rollout_batch + # if len(cur_rollout_batch) != len(ready_env_ids): + # raise RuntimeError( + # f"The length of the collected_rollout_batch {len(cur_rollout_batch)}) is not equal to the length of ready_env_ids" + # f"{len(ready_env_ids)}. This should not happen and could be a bug!", + # ) # restore the state: if the last state is None, it won't store - last_state = self.data.policy.pop("hidden_state", None) # get the next action - if random: - try: - act_sample = [self._action_space[i].sample() for i in ready_env_ids] - except TypeError: # envpool's action space is not for per-env - act_sample = [self._action_space.sample() for _ in ready_env_ids] - act_sample = self.policy.map_action_inverse(act_sample) # type: ignore - self.data.update(act=act_sample) - else: - if no_grad: - with torch.no_grad(): # faster than retain_grad version - # self.data.obs will be used by agent to get result - result = self.policy(self.data, last_state) - else: - result = self.policy(self.data, last_state) - # update state / act / policy into self.data - policy = result.get("policy", Batch()) - assert isinstance(policy, Batch) - state = result.get("state", None) - if state is not None: - policy.hidden_state = state # save state into buffer - act = to_numpy(result.act) - if self.exploration_noise: - act = self.policy.exploration_noise(act, self.data) - self.data.update(policy=policy, act=act) - - # get bounded and remapped actions first (not saved into buffer) - action_remap = self.policy.map_action(self.data.act) - # step in env + ( + act_RA, + act_normalized_RA, + policy_R, + hidden_state_RH, + ) = self._compute_action_policy_hidden( + random=random, + ready_env_ids_R=ready_env_ids_R, + use_grad=use_grad, + last_obs_RO=last_obs_RO, + last_info_R=last_info_R, + last_hidden_state_RH=last_hidden_state_RH, + ) - obs_next, rew, terminated, truncated, info = self.env.step( - action_remap, - ready_env_ids, + obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step( + act_normalized_RA, + ready_env_ids_R, ) - done = np.logical_or(terminated, truncated) - - self.data.update( - obs_next=obs_next, - rew=rew, - terminated=terminated, - truncated=truncated, - done=done, - info=info, + if isinstance(info_R, dict): # type: ignore[unreachable] + # This can happen if the env is an envpool env. 
Then the info returned by step is a dict + info_R = _dict_of_arr_to_arr_of_dicts(info_R) # type: ignore[unreachable] + done_R = np.logical_or(terminated_R, truncated_R) + + current_iteration_batch = cast( + RolloutBatchProtocol, + Batch( + obs=last_obs_RO, + act=act_RA, + policy=policy_R, + obs_next=obs_next_RO, + rew=rew_R, + terminated=terminated_R, + truncated=truncated_R, + done=done_R, + info=info_R, + ), ) - if self.preprocess_fn: - self.data.update( - self.preprocess_fn( - obs_next=self.data.obs_next, - rew=self.data.rew, - done=self.data.done, - info=self.data.info, - policy=self.data.policy, - env_id=ready_env_ids, - act=self.data.act, - ), - ) + # TODO: only makes sense if render_mode is human. + # Also, doubtful whether it makes sense at all for true vectorized envs if render: self.env.render() - if render > 0 and not np.isclose(render, 0): + if not np.isclose(render, 0): time.sleep(render) # add data into the buffer - ptr, ep_rew, ep_len, ep_idx = self.buffer.add(self.data, buffer_ids=ready_env_ids) + ptr_R, ep_rew_R, ep_len_R, ep_idx_R = self.buffer.add( + current_iteration_batch, + buffer_ids=ready_env_ids_R, + ) # collect statistics - step_count += len(ready_env_ids) - - if np.any(done): - env_ind_local = np.where(done)[0] - env_ind_global = ready_env_ids[env_ind_local] - episode_count += len(env_ind_local) - episode_lens.extend(ep_len[env_ind_local]) - episode_returns.extend(ep_rew[env_ind_local]) - episode_start_indices.extend(ep_idx[env_ind_local]) + num_episodes_done_this_iter = np.sum(done_R) + num_collected_episodes += num_episodes_done_this_iter + step_count += len(ready_env_ids_R) + + # preparing for the next iteration + # obs_next, info and hidden_state will be modified inplace in the code below, so we copy to not affect the data in the buffer + last_obs_RO = copy(obs_next_RO) + last_info_R = copy(info_R) + last_hidden_state_RH = copy(hidden_state_RH) + + # Preparing last_obs_RO, last_info_R, last_hidden_state_RH for the next while-loop iteration + # Resetting envs that reached done, or removing some of them from the collection if needed (see below) + if num_episodes_done_this_iter > 0: + # TODO: adjust the whole index story, don't use np.where, just slice with boolean arrays + # D - number of envs that reached done in the rollout above + env_ind_local_D = np.where(done_R)[0] + env_ind_global_D = ready_env_ids_R[env_ind_local_D] + episode_lens.extend(ep_len_R[env_ind_local_D]) + episode_returns.extend(ep_rew_R[env_ind_local_D]) + episode_start_indices.extend(ep_idx_R[env_ind_local_D]) # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. - self._reset_env_with_ids(env_ind_local, env_ind_global, gym_reset_kwargs) - for i in env_ind_local: - self._reset_state(i) - # remove surplus env id from ready_env_ids - # to avoid bias in selecting environments + obs_reset_DO, info_reset_D = self.env.reset( + env_id=env_ind_global_D, + **gym_reset_kwargs, + ) + + # Set the hidden state to zero or None for the envs that reached done + # TODO: does it have to be so complicated? 
We should have a single clear type for hidden_state instead of + # this complex logic + self._reset_hidden_state_based_on_type(env_ind_local_D, last_hidden_state_RH) + + # preparing for the next iteration + last_obs_RO[env_ind_local_D] = obs_reset_DO + last_info_R[env_ind_local_D] = info_reset_D + + # Handling the case when we have more ready envs than desired and are not done yet + # + # This can only happen if we are collecting a fixed number of episodes + # If we have more ready envs than there are remaining episodes to collect, + # we will remove some of them for the next rollout + # One effect of this is the following: only envs that have completed an episode + # in the last step can ever be removed from the ready envs. + # Thus, this guarantees that each env will contribute at least one episode to the + # collected data (the buffer). This effect was previous called "avoiding bias in selecting environments" + # However, it is not at all clear whether this is actually useful or necessary. + # Additional naming convention: + # S - number of surplus envs + # TODO: can the whole block be removed? If we have too many episodes, we could just strip the last ones. + # Changing R to R-S highly increases the complexity of the code. if n_episode: - surplus_env_num = len(ready_env_ids) - (n_episode - episode_count) + remaining_episodes_to_collect = n_episode - num_collected_episodes + surplus_env_num = len(ready_env_ids_R) - remaining_episodes_to_collect if surplus_env_num > 0: - mask = np.ones_like(ready_env_ids, dtype=bool) - mask[env_ind_local[:surplus_env_num]] = False - ready_env_ids = ready_env_ids[mask] - self.data = self.data[mask] - - self.data.obs = self.data.obs_next - - if (n_step and step_count >= n_step) or (n_episode and episode_count >= n_episode): + # R becomes R-S here, preparing for the next iteration in while loop + # Everything that was of length R needs to be filtered and become of length R-S. + # Note that this won't be the last iteration, as one iteration equals one + # step and we still need to collect the remaining episodes to reach the breaking condition. 
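+                        # Example with hypothetical numbers: with 4 ready envs and only one
+                        # episode left to collect, surplus_env_num == 3, so up to 3 of the envs
+                        # that just finished an episode are removed from the ready set below.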
+ + # creating the mask + env_to_be_ignored_ind_local_S = env_ind_local_D[:surplus_env_num] + env_should_remain_R = np.ones_like(ready_env_ids_R, dtype=bool) + env_should_remain_R[env_to_be_ignored_ind_local_S] = False + # stripping the "idle" indices, shortening the relevant quantities from R to R-S + ready_env_ids_R = ready_env_ids_R[env_should_remain_R] + last_obs_RO = last_obs_RO[env_should_remain_R] + last_info_R = last_info_R[env_should_remain_R] + if hidden_state_RH is not None: + last_hidden_state_RH = last_hidden_state_RH[env_should_remain_R] # type: ignore[index] + + if (n_step and step_count >= n_step) or ( + n_episode and num_collected_episodes >= n_episode + ): break # generate statistics self.collect_step += step_count - self.collect_episode += episode_count + self.collect_episode += num_collected_episodes collect_time = max(time.time() - start_time, 1e-9) self.collect_time += collect_time - if n_episode: - data = Batch( - obs={}, - act={}, - rew={}, - terminated={}, - truncated={}, - done={}, - obs_next={}, - info={}, - policy={}, - ) - self.data = cast(RolloutBatchProtocol, data) - self.reset_env() + if n_step: + # persist for future collect iterations + self._pre_collect_obs_RO = last_obs_RO + self._pre_collect_info_R = last_info_R + self._pre_collect_hidden_state_RH = last_hidden_state_RH + elif n_episode: + # reset envs and the _pre_collect fields + self.reset_env(gym_reset_kwargs) # todo still necessary? - return CollectStats( - n_collected_episodes=episode_count, + return CollectStats.with_autogenerated_stats( + returns=np.array(episode_returns), + lens=np.array(episode_lens), + n_collected_episodes=num_collected_episodes, n_collected_steps=step_count, collect_time=collect_time, collect_speed=step_count / collect_time, - returns=np.array(episode_returns), - returns_stat=SequenceSummaryStats.from_sequence(episode_returns) - if len(episode_returns) > 0 - else None, - lens=np.array(episode_lens, int), - lens_stat=SequenceSummaryStats.from_sequence(episode_lens) - if len(episode_lens) > 0 - else None, ) + def _reset_hidden_state_based_on_type( + self, + env_ind_local_D: np.ndarray, + last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None, + ) -> None: + if isinstance(last_hidden_state_RH, torch.Tensor): + last_hidden_state_RH[env_ind_local_D].zero_() # type: ignore[index] + elif isinstance(last_hidden_state_RH, np.ndarray): + last_hidden_state_RH[env_ind_local_D] = ( + None if last_hidden_state_RH.dtype == object else 0 + ) + elif isinstance(last_hidden_state_RH, Batch): + last_hidden_state_RH.empty_(env_ind_local_D) + # todo is this inplace magic and just working? + class AsyncCollector(Collector): """Async Collector handles async vector environment. @@ -429,7 +608,6 @@ def __init__( policy: BasePolicy, env: BaseVectorEnv, buffer: ReplayBuffer | None = None, - preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None, exploration_noise: bool = False, ) -> None: # assert env.is_async @@ -438,13 +616,48 @@ def __init__( policy, env, buffer, - preprocess_fn, exploration_noise, ) + # E denotes the number of parallel environments: self.env_num + # At init, E=R but during collection R <= E + # Keep in sync with reset! 
+ self._ready_env_ids_R: np.ndarray = np.arange(self.env_num) + self._current_obs_in_all_envs_EO: np.ndarray | None = copy(self._pre_collect_obs_RO) + self._current_info_in_all_envs_E: np.ndarray | None = copy(self._pre_collect_info_R) + self._current_hidden_state_in_all_envs_EH: np.ndarray | torch.Tensor | Batch | None = copy( + self._pre_collect_hidden_state_RH, + ) + self._current_action_in_all_envs_EA: np.ndarray = np.empty(self.env_num) + self._current_policy_in_all_envs_E: Batch | None = None - def reset_env(self, gym_reset_kwargs: dict[str, Any] | None = None) -> None: - super().reset_env(gym_reset_kwargs) - self._ready_env_ids = np.arange(self.env_num) + def reset( + self, + reset_buffer: bool = True, + reset_stats: bool = True, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> None: + """Reset the environment, statistics, and data needed to start the collection. + + :param reset_buffer: if true, reset the replay buffer attached + to the collector. + :param reset_stats: if true, reset the statistics attached to the collector. + :param gym_reset_kwargs: extra keyword arguments to pass into the environment's + reset function. Defaults to None (extra keyword arguments) + """ + # This sets the _pre_collect attrs + super().reset( + reset_buffer=reset_buffer, + reset_stats=reset_stats, + gym_reset_kwargs=gym_reset_kwargs, + ) + # Keep in sync with init! + self._ready_env_ids_R = np.arange(self.env_num) + # E denotes the number of parallel environments self.env_num + self._current_obs_in_all_envs_EO = copy(self._pre_collect_obs_RO) + self._current_info_in_all_envs_E = copy(self._pre_collect_info_R) + self._current_hidden_state_in_all_envs_EH = copy(self._pre_collect_hidden_state_RH) + self._current_action_in_all_envs_EA = np.empty(self.env_num) + self._current_policy_in_all_envs_E = None def collect( self, @@ -453,22 +666,27 @@ def collect( random: bool = False, render: float | None = None, no_grad: bool = True, + reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, ) -> CollectStats: - """Collect a specified number of step or episode with async env setting. + """Collect a specified number of steps or episodes with async env setting. - This function doesn't collect exactly n_step or n_episode number of - transitions. Instead, in order to support async setting, it may collect more - than given n_step or n_episode transitions and save into buffer. + This function does not collect an exact number of transitions specified by n_step or + n_episode. Instead, to support the asynchronous setting, it may collect more transitions + than requested by n_step or n_episode and save them into the buffer. :param n_step: how many steps you want to collect. :param n_episode: how many episodes you want to collect. - :param random: whether to use random policy for collecting data. Default + :param random: whether to use random policy_R for collecting data. Default to False. :param render: the sleep time between rendering consecutive frames. Default to None (no rendering). - :param no_grad: whether to retain gradient in policy.forward(). Default to + :param no_grad: whether to retain gradient in policy_R.forward(). Default to True (no gradient retaining). + :param reset_before_collect: whether to reset the environment before + collecting data. It has only an effect if n_episode is not None, i.e. + if one wants to collect a fixed number of episodes. + (The collector needs the initial obs and info to function properly.) 
:param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. Defaults to None (extra keyword arguments) @@ -479,6 +697,9 @@ def collect( :return: A dataclass object """ + use_grad = not no_grad + gym_reset_kwargs = gym_reset_kwargs or {} + # collect at least n_step or n_episode if n_step is not None: assert n_episode is None, ( @@ -494,104 +715,123 @@ def collect( "in AsyncCollector.collect().", ) - ready_env_ids = self._ready_env_ids + if reset_before_collect: + # first we need to step all envs to be able to interact with them + if self.env.waiting_id: + self.env.step(None, id=self.env.waiting_id) + self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) start_time = time.time() step_count = 0 - episode_count = 0 + num_collected_episodes = 0 episode_returns: list[float] = [] episode_lens: list[int] = [] episode_start_indices: list[int] = [] + ready_env_ids_R = self._ready_env_ids_R + # last_obs_RO= self._current_obs_in_all_envs_EO[ready_env_ids_R] # type: ignore[index] + # last_info_R = self._current_info_in_all_envs_E[ready_env_ids_R] # type: ignore[index] + # last_hidden_state_RH = self._current_hidden_state_in_all_envs_EH[ready_env_ids_R] # type: ignore[index] + # last_obs_RO = self._pre_collect_obs_RO + # last_info_R = self._pre_collect_info_R + # last_hidden_state_RH = self._pre_collect_hidden_state_RH + if self._current_obs_in_all_envs_EO is None or self._current_info_in_all_envs_E is None: + raise RuntimeError( + "Current obs or info array is None, did you call reset or pass reset_at_collect=True?", + ) + + last_obs_RO = self._current_obs_in_all_envs_EO[ready_env_ids_R] + last_info_R = self._current_info_in_all_envs_E[ready_env_ids_R] + last_hidden_state_RH = _nullable_slice( + self._current_hidden_state_in_all_envs_EH, + ready_env_ids_R, + ) + # Each iteration of the AsyncCollector is only stepping a subset of the + # envs. The last observation/ hidden state of the ones not included in + # the current iteration has to be retained. while True: - whole_data = self.data - self.data = self.data[ready_env_ids] - assert len(whole_data) == self.env_num # major difference - # restore the state: if the last state is None, it won't store - last_state = self.data.policy.pop("hidden_state", None) + # todo do we need this? 
+ # todo extend to all current attributes but some could be None at init + if self._current_obs_in_all_envs_EO is None: + raise RuntimeError( + "Current obs is None, did you call reset or pass reset_at_collect=True?", + ) + if ( + not len(self._current_obs_in_all_envs_EO) + == len(self._current_action_in_all_envs_EA) + == self.env_num + ): # major difference + raise RuntimeError( + f"{len(self._current_obs_in_all_envs_EO)=} and" + f"{len(self._current_action_in_all_envs_EA)=} have to equal" + f" {self.env_num=} as it tracks the current transition" + f"in all envs", + ) # get the next action - if random: - try: - act_sample = [self._action_space[i].sample() for i in ready_env_ids] - except TypeError: # envpool's action space is not for per-env - act_sample = [self._action_space.sample() for _ in ready_env_ids] - act_sample = self.policy.map_action_inverse(act_sample) # type: ignore - self.data.update(act=act_sample) + ( + act_RA, + act_normalized_RA, + policy_R, + hidden_state_RH, + ) = self._compute_action_policy_hidden( + random=random, + ready_env_ids_R=ready_env_ids_R, + use_grad=use_grad, + last_obs_RO=last_obs_RO, + last_info_R=last_info_R, + last_hidden_state_RH=last_hidden_state_RH, + ) + + # save act_RA/policy_R/ hidden_state_RH before env.step + self._current_action_in_all_envs_EA[ready_env_ids_R] = act_RA + if self._current_policy_in_all_envs_E: + self._current_policy_in_all_envs_E[ready_env_ids_R] = policy_R else: - if no_grad: - with torch.no_grad(): # faster than retain_grad version - # self.data.obs will be used by agent to get result - result = self.policy(self.data, last_state) + self._current_policy_in_all_envs_E = policy_R # first iteration + if hidden_state_RH is not None: + if self._current_hidden_state_in_all_envs_EH is not None: + # Need to cast since if it's a Tensor, the assignment might in fact fail if hidden_state_RH is not + # a tensor as well. This is hard to express with proper typing, even using @overload, so we cheat + # and hope that if one of the two is a tensor, the other one is as well. 
+ self._current_hidden_state_in_all_envs_EH = cast( + np.ndarray | Batch, + self._current_hidden_state_in_all_envs_EH, + ) + self._current_hidden_state_in_all_envs_EH[ready_env_ids_R] = hidden_state_RH else: - result = self.policy(self.data, last_state) - # update state / act / policy into self.data - policy = result.get("policy", Batch()) - assert isinstance(policy, Batch) - state = result.get("state", None) - if state is not None: - policy.hidden_state = state # save state into buffer - act = to_numpy(result.act) - if self.exploration_noise: - act = self.policy.exploration_noise(act, self.data) - self.data.update(policy=policy, act=act) - - # save act/policy before env.step - try: - whole_data.act[ready_env_ids] = self.data.act # type: ignore - whole_data.policy[ready_env_ids] = self.data.policy - except ValueError: - alloc_by_keys_diff(whole_data, self.data, self.env_num, False) - whole_data[ready_env_ids] = self.data # lots of overhead - - # get bounded and remapped actions first (not saved into buffer) - action_remap = self.policy.map_action(self.data.act) + self._current_hidden_state_in_all_envs_EH = hidden_state_RH + # step in env - obs_next, rew, terminated, truncated, info = self.env.step( - action_remap, - ready_env_ids, + obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step( + act_normalized_RA, + ready_env_ids_R, ) - done = np.logical_or(terminated, truncated) - - # change self.data here because ready_env_ids has changed + done_R = np.logical_or(terminated_R, truncated_R) + # Not all environments of the AsyncCollector might have performed a step in this iteration. + # Change batch_of_envs_with_step_in_this_iteration here to reflect that ready_env_ids_R has changed. + # This means especially that R is potentially changing every iteration try: - ready_env_ids = info["env_id"] + ready_env_ids_R = cast(np.ndarray, info_R["env_id"]) + # TODO: don't use bare Exception! 
except Exception: - ready_env_ids = np.array([i["env_id"] for i in info]) - self.data = whole_data[ready_env_ids] - - self.data.update( - obs_next=obs_next, - rew=rew, - terminated=terminated, - truncated=truncated, - info=info, + ready_env_ids_R = np.array([i["env_id"] for i in info_R]) + + current_iteration_batch = cast( + RolloutBatchProtocol, + Batch( + obs=self._current_obs_in_all_envs_EO[ready_env_ids_R], + act=self._current_action_in_all_envs_EA[ready_env_ids_R], + policy=self._current_policy_in_all_envs_E[ready_env_ids_R], + obs_next=obs_next_RO, + rew=rew_R, + terminated=terminated_R, + truncated=truncated_R, + done=done_R, + info=info_R, + ), ) - if self.preprocess_fn: - try: - self.data.update( - self.preprocess_fn( - obs_next=self.data.obs_next, - rew=self.data.rew, - terminated=self.data.terminated, - truncated=self.data.truncated, - info=self.data.info, - env_id=ready_env_ids, - act=self.data.act, - ), - ) - except TypeError: - self.data.update( - self.preprocess_fn( - obs_next=self.data.obs_next, - rew=self.data.rew, - done=self.data.done, - info=self.data.info, - env_id=ready_env_ids, - act=self.data.act, - ), - ) if render: self.env.render() @@ -599,60 +839,77 @@ def collect( time.sleep(render) # add data into the buffer - ptr, ep_rew, ep_len, ep_idx = self.buffer.add(self.data, buffer_ids=ready_env_ids) + ptr_R, ep_rew_R, ep_len_R, ep_idx_R = self.buffer.add( + current_iteration_batch, + buffer_ids=ready_env_ids_R, + ) # collect statistics - step_count += len(ready_env_ids) - - if np.any(done): - env_ind_local = np.where(done)[0] - env_ind_global = ready_env_ids[env_ind_local] - episode_count += len(env_ind_local) - episode_lens.extend(ep_len[env_ind_local]) - episode_returns.extend(ep_rew[env_ind_local]) - episode_start_indices.extend(ep_idx[env_ind_local]) - # now we copy obs_next to obs, but since there might be + num_episodes_done_this_iter = np.sum(done_R) + step_count += len(ready_env_ids_R) + num_collected_episodes += num_episodes_done_this_iter + + # preparing for the next iteration + # todo do we need the copy stuff (tests pass also without) + # todo seem we can get rid of this last_sth stuff altogether + last_obs_RO = copy(obs_next_RO) + last_info_R = copy(info_R) + last_hidden_state_RH = copy(self._current_hidden_state_in_all_envs_EH[ready_env_ids_R]) # type: ignore[index] + + if num_episodes_done_this_iter: + env_ind_local_D = np.where(done_R)[0] + env_ind_global_D = ready_env_ids_R[env_ind_local_D] + episode_lens.extend(ep_len_R[env_ind_local_D]) + episode_returns.extend(ep_rew_R[env_ind_local_D]) + episode_start_indices.extend(ep_idx_R[env_ind_local_D]) + + # now we copy obs_next_RO to obs, but since there might be # finished episodes, we have to reset finished envs first. 
- self._reset_env_with_ids(env_ind_local, env_ind_global, gym_reset_kwargs) - for i in env_ind_local: - self._reset_state(i) + obs_reset_DO, info_reset_D = self.env.reset( + env_id=env_ind_global_D, + **gym_reset_kwargs, + ) + last_obs_RO[env_ind_local_D] = obs_reset_DO + last_info_R[env_ind_local_D] = info_reset_D + + self._reset_hidden_state_based_on_type(env_ind_local_D, last_hidden_state_RH) + + # update based on the current transition in all envs + self._current_obs_in_all_envs_EO[ready_env_ids_R] = last_obs_RO + # this is a list, so loop over + for idx, ready_env_id in enumerate(ready_env_ids_R): + self._current_info_in_all_envs_E[ready_env_id] = last_info_R[idx] + if self._current_hidden_state_in_all_envs_EH is not None: + # Need to cast since if it's a Tensor, the assignment might in fact fail if hidden_state_RH is not + # a tensor as well. This is hard to express with proper typing, even using @overload, so we cheat + # and hope that if one of the two is a tensor, the other one is as well. + self._current_hidden_state_in_all_envs_EH = cast( + np.ndarray | Batch, + self._current_hidden_state_in_all_envs_EH, + ) + self._current_hidden_state_in_all_envs_EH[ready_env_ids_R] = last_hidden_state_RH + else: + self._current_hidden_state_in_all_envs_EH = last_hidden_state_RH - try: - # Need to ignore types b/c according to mypy Tensors cannot be indexed - # by arrays (which they can...) - whole_data.obs[ready_env_ids] = self.data.obs_next # type: ignore - whole_data.rew[ready_env_ids] = self.data.rew - whole_data.done[ready_env_ids] = self.data.done - whole_data.info[ready_env_ids] = self.data.info # type: ignore - except ValueError: - alloc_by_keys_diff(whole_data, self.data, self.env_num, False) - self.data.obs = self.data.obs_next - # lots of overhead - whole_data[ready_env_ids] = self.data - self.data = whole_data - - if (n_step and step_count >= n_step) or (n_episode and episode_count >= n_episode): + if (n_step and step_count >= n_step) or ( + n_episode and num_collected_episodes >= n_episode + ): break - self._ready_env_ids = ready_env_ids - # generate statistics self.collect_step += step_count - self.collect_episode += episode_count + self.collect_episode += num_collected_episodes collect_time = max(time.time() - start_time, 1e-9) self.collect_time += collect_time - return CollectStats( - n_collected_episodes=episode_count, + # persist for future collect iterations + self._ready_env_ids_R = ready_env_ids_R + + return CollectStats.with_autogenerated_stats( + returns=np.array(episode_returns), + lens=np.array(episode_lens), + n_collected_episodes=num_collected_episodes, n_collected_steps=step_count, collect_time=collect_time, collect_speed=step_count / collect_time, - returns=np.array(episode_returns), - returns_stat=SequenceSummaryStats.from_sequence(episode_returns) - if len(episode_returns) > 0 - else None, - lens=np.array(episode_lens, int), - lens_stat=SequenceSummaryStats.from_sequence(episode_lens) - if len(episode_lens) > 0 - else None, ) diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index 205c2d5f1..2df462da5 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -12,6 +12,7 @@ # TODO: confusing name, could actually return a batch... 
# Overrides and generic types should be added +# todo check for ActBatchProtocol @no_type_check def to_numpy(x: Any) -> Batch | np.ndarray: """Return an object without torch.Tensor.""" diff --git a/tianshou/env/venv_wrappers.py b/tianshou/env/venv_wrappers.py index 4e30e711b..9297ddebb 100644 --- a/tianshou/env/venv_wrappers.py +++ b/tianshou/env/venv_wrappers.py @@ -44,14 +44,14 @@ def set_env_attr( def reset( self, - id: int | list[int] | np.ndarray | None = None, + env_id: int | list[int] | np.ndarray | None = None, **kwargs: Any, - ) -> tuple[np.ndarray, dict | list[dict]]: - return self.venv.reset(id, **kwargs) + ) -> tuple[np.ndarray, np.ndarray]: + return self.venv.reset(env_id, **kwargs) def step( self, - action: np.ndarray | torch.Tensor, + action: np.ndarray | torch.Tensor | None, id: int | list[int] | np.ndarray | None = None, ) -> gym_new_venv_step_type: return self.venv.step(action, id) @@ -80,10 +80,10 @@ def __init__(self, venv: BaseVectorEnv, update_obs_rms: bool = True) -> None: def reset( self, - id: int | list[int] | np.ndarray | None = None, + env_id: int | list[int] | np.ndarray | None = None, **kwargs: Any, - ) -> tuple[np.ndarray, dict | list[dict]]: - obs, info = self.venv.reset(id, **kwargs) + ) -> tuple[np.ndarray, np.ndarray]: + obs, info = self.venv.reset(env_id, **kwargs) if isinstance(obs, tuple): # type: ignore raise TypeError( @@ -98,7 +98,7 @@ def reset( def step( self, - action: np.ndarray | torch.Tensor, + action: np.ndarray | torch.Tensor | None, id: int | list[int] | np.ndarray | None = None, ) -> gym_new_venv_step_type: step_results = self.venv.step(action, id) diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index dfcd12e85..e9309f9ec 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -190,11 +190,13 @@ def _assert_id(self, id: list[int] | np.ndarray) -> None: ), f"Cannot interact with environment {i} which is stepping now." assert i in self.ready_id, f"Can only interact with ready environments {self.ready_id}." + # TODO: for now, has to be kept in sync with reset in EnvPoolMixin + # In particular, can't rename env_id to env_ids def reset( self, - id: int | list[int] | np.ndarray | None = None, + env_id: int | list[int] | np.ndarray | None = None, **kwargs: Any, - ) -> tuple[np.ndarray, dict | list[dict]]: + ) -> tuple[np.ndarray, np.ndarray]: """Reset the state of some envs and return initial observations. If id is None, reset the state of all the environments and return @@ -202,14 +204,14 @@ def reset( the given id, either an int or a list. """ self._assert_is_not_closed() - id = self._wrap_id(id) + env_id = self._wrap_id(env_id) if self.is_async: - self._assert_id(id) + self._assert_id(env_id) # send(None) == reset() in worker - for i in id: - self.workers[i].send(None, **kwargs) - ret_list = [self.workers[i].recv() for i in id] + for id in env_id: + self.workers[id].send(None, **kwargs) + ret_list = [self.workers[id].recv() for id in env_id] assert ( isinstance(ret_list[0], tuple | list) @@ -229,12 +231,12 @@ def reset( except ValueError: # different len(obs) obs = np.array(obs_list, dtype=object) - infos = [r[1] for r in ret_list] - return obs, infos # type: ignore + infos = np.array([r[1] for r in ret_list]) + return obs, infos def step( self, - action: np.ndarray | torch.Tensor, + action: np.ndarray | torch.Tensor | None, id: int | list[int] | np.ndarray | None = None, ) -> gym_new_venv_step_type: """Run one timestep of some environments' dynamics. @@ -248,6 +250,8 @@ def step( batch_done, batch_info) in numpy format. 
:param numpy.ndarray action: a batch of action provided by the agent. + If the venv is async, the action can be None, which will result + in all arrays in the returned tuple being empty. :return: A tuple consisting of either: @@ -271,6 +275,8 @@ def step( self._assert_is_not_closed() id = self._wrap_id(id) if not self.is_async: + if action is None: + raise ValueError("action must be not-None for non-async") assert len(action) == len(id) for i, j in enumerate(id): self.workers[j].send(action[i]) diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 39d367e5f..f71a7f981 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -93,7 +93,14 @@ def create_train_test_collector( self, policy: BasePolicy, envs: Environments, + reset_collectors: bool = True, ) -> tuple[Collector, Collector]: + """:param policy: + :param envs: + :param reset_collectors: Whether to reset the collectors before returning them. + Setting to True means that the envs will be reset as well. + :return: + """ buffer_size = self.sampling_config.buffer_size train_envs = envs.train_envs buffer: ReplayBuffer @@ -114,6 +121,10 @@ def create_train_test_collector( ) train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, envs.test_envs) + if reset_collectors: + train_collector.reset() + test_collector.reset() + if self.sampling_config.start_timesteps > 0: log.info( f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})", diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 5b7d3883d..17f0550fc 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -311,7 +311,7 @@ def _watch_agent( ) -> None: policy.eval() collector = Collector(policy, env) - result = collector.collect(n_episode=num_episodes, render=render) + result = collector.collect(n_episode=num_episodes, render=render, reset_before_collect=True) assert result.returns_stat is not None # for mypy assert result.lens_stat is not None # for mypy log.info( diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 087e3cb5e..7df7ebd2c 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -18,6 +18,7 @@ from tianshou.data.buffer.base import TBuffer from tianshou.data.types import ( ActBatchProtocol, + ActStateBatchProtocol, BatchWithReturnsProtocol, ObsBatchProtocol, RolloutBatchProtocol, @@ -233,11 +234,14 @@ def set_agent_id(self, agent_id: int) -> None: # have a method to add noise to action. # So we add the default behavior here. It's a little messy, maybe one can # find a better way to do this. + + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + def exploration_noise( self, - act: np.ndarray | BatchProtocol, - batch: RolloutBatchProtocol, - ) -> np.ndarray | BatchProtocol: + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: """Modify the action from policy.forward with exploration noise. NOTE: currently does not add any noise! Needs to be overridden by subclasses @@ -287,7 +291,7 @@ def forward( batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, - ) -> ActBatchProtocol: + ) -> ActBatchProtocol | ActStateBatchProtocol: # TODO: make consistent typing """Compute action over the given batch data. 
:return: A :class:`~tianshou.data.Batch` which MUST have the following keys: diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 54d560b9d..9a603b7de 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Self +from typing import Any, Literal, Self, TypeVar import gymnasium as gym import numpy as np @@ -105,11 +105,13 @@ def forward( """ return self.policy.forward(batch, state, **kwargs) + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + def exploration_noise( self, - act: np.ndarray | BatchProtocol, - batch: RolloutBatchProtocol, - ) -> np.ndarray | BatchProtocol: + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: return self.policy.exploration_noise(act, batch) def set_eps(self, eps: float) -> None: diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdq.py index a91ea0093..ba3747772 100644 --- a/tianshou/policy/modelfree/bdq.py +++ b/tianshou/policy/modelfree/bdq.py @@ -8,6 +8,7 @@ from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch, to_torch_as from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( + ActBatchProtocol, BatchWithReturnsProtocol, ModelOutputBatchProtocol, ObsBatchProtocol, @@ -182,11 +183,13 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQN return BDQNTrainingStats(loss=loss.item()) # type: ignore[return-value] + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + def exploration_noise( self, - act: np.ndarray | BatchProtocol, - batch: RolloutBatchProtocol, - ) -> np.ndarray | BatchProtocol: + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): bsz = len(act) rand_mask = np.random.rand(bsz) < self.eps diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 1b371d4b3..b54860e6b 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -10,6 +10,7 @@ from tianshou.data import Batch, ReplayBuffer from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( + ActBatchProtocol, ActStateBatchProtocol, BatchWithReturnsProtocol, ObsBatchProtocol, @@ -208,11 +209,13 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDDPG return DDPGTrainingStats(actor_loss=actor_loss.item(), critic_loss=critic_loss.item()) # type: ignore[return-value] + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + def exploration_noise( self, - act: np.ndarray | BatchProtocol, - batch: RolloutBatchProtocol, - ) -> np.ndarray | BatchProtocol: + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: if self._exploration_noise is None: return act if isinstance(act, np.ndarray): diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index b271cbd26..d1054f9f6 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -8,8 +8,7 @@ from torch.distributions import Categorical from tianshou.data import Batch, ReplayBuffer, to_torch -from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import SACPolicy 
from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.sac import SACTrainingStats @@ -184,9 +183,11 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDisc alpha_loss=None if not self.is_auto_alpha else alpha_loss.item(), ) + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + def exploration_noise( self, - act: np.ndarray | BatchProtocol, - batch: RolloutBatchProtocol, - ) -> np.ndarray | BatchProtocol: + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: return act diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 5b90510b4..ad5f7dd0d 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -9,6 +9,7 @@ from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( + ActBatchProtocol, BatchWithReturnsProtocol, ModelOutputBatchProtocol, ObsBatchProtocol, @@ -232,11 +233,13 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNT return DQNTrainingStats(loss=loss.item()) # type: ignore[return-value] + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + def exploration_noise( self, - act: np.ndarray | BatchProtocol, - batch: RolloutBatchProtocol, - ) -> np.ndarray | BatchProtocol: + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): bsz = len(act) rand_mask = np.random.rand(bsz) < self.eps diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index e41e069d1..7a7be5888 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -1,11 +1,11 @@ -from typing import Any, Literal, Protocol, Self, cast, overload +from typing import Any, Literal, Protocol, Self, TypeVar, cast, overload import numpy as np from overrides import override from tianshou.data import Batch, ReplayBuffer from tianshou.data.batch import BatchProtocol, IndexType -from tianshou.data.types import RolloutBatchProtocol +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import BasePolicy from tianshou.policy.base import TLearningRateScheduler, TrainingStats @@ -160,16 +160,18 @@ def process_fn( # type: ignore buffer._meta.rew = save_rew return Batch(results) + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + def exploration_noise( self, - act: np.ndarray | BatchProtocol, - batch: RolloutBatchProtocol, - ) -> np.ndarray | BatchProtocol: + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: """Add exploration noise from sub-policy onto act.""" - assert isinstance( - batch.obs, - BatchProtocol, - ), f"here only observations of type Batch are permitted, but got {type(batch.obs)}" + if not isinstance(batch.obs, Batch): + raise TypeError( + f"here only observations of type Batch are permitted, but got {type(batch.obs)}", + ) for agent_id, policy in self.policies.items(): agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0] if len(agent_index) == 0: @@ -223,7 +225,7 @@ def forward( # type: ignore results.append((False, np.array([-1]), Batch(), Batch(), Batch())) continue tmp_batch = batch[agent_index] - if isinstance(tmp_batch.rew, np.ndarray): + if "rew" in tmp_batch.keys() and isinstance(tmp_batch.rew, np.ndarray): # 
reward can be empty Batch (after initial reset) or nparray. tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]] if not hasattr(tmp_batch.obs, "mask"): diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index c6d87d194..675112fae 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -237,7 +237,13 @@ def __init__( self.stop_fn_flag = False self.iter_num = 0 - def reset(self) -> None: + def _reset_collectors(self, reset_buffer: bool = False) -> None: + if self.train_collector is not None: + self.train_collector.reset(reset_buffer=reset_buffer) + if self.test_collector is not None: + self.test_collector.reset(reset_buffer=reset_buffer) + + def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> None: """Initialize or reset the instance to yield a new iterator from zero.""" self.is_run = False self.env_step = 0 @@ -250,16 +256,18 @@ def reset(self) -> None: self.last_rew, self.last_len = 0.0, 0.0 self.start_time = time.time() - if self.train_collector is not None: - self.train_collector.reset_stat() - if self.train_collector.policy != self.policy or self.test_collector is None: - self.test_in_train = False + if reset_collectors: + self._reset_collectors(reset_buffer=reset_buffer) + + if self.train_collector is not None and ( + self.train_collector.policy != self.policy or self.test_collector is None + ): + self.test_in_train = False if self.test_collector is not None: assert self.episode_per_test is not None assert not isinstance(self.test_collector, AsyncCollector) # Issue 700 - self.test_collector.reset_stat() test_result = test_episode( self.policy, self.test_collector, @@ -284,7 +292,7 @@ def reset(self) -> None: self.iter_num = 0 def __iter__(self): # type: ignore - self.reset() + self.reset(reset_collectors=True, reset_buffer=False) return self def __next__(self) -> EpochStats: @@ -308,8 +316,8 @@ def __next__(self) -> EpochStats: # perform n step_per_epoch with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t: + train_stat: CollectStatsBase while t.n < t.total and not self.stop_fn_flag: - train_stat: CollectStatsBase if self.train_collector is not None: train_stat, self.stop_fn_flag = self.train_step() pbar_data_dict = { @@ -515,12 +523,14 @@ def policy_update_fn( stats of the whole dataset """ - def run(self) -> InfoStats: + def run(self, reset_prior_to_run: bool = True) -> InfoStats: """Consume iterator. See itertools - recipes. Use functions that consume iterators at C speed (feed the entire iterator into a zero-length deque). """ + if reset_prior_to_run: + self.reset() try: self.is_run = True deque(self, maxlen=0) # feed the entire iterator into a zero-length deque diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 300c7c470..7a96ea06f 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -26,8 +26,7 @@ def test_episode( reward_metric: Callable[[np.ndarray], np.ndarray] | None = None, ) -> CollectStats: """A simple wrapper of testing policy in collector.""" - collector.reset_env() - collector.reset_buffer() + collector.reset(reset_stats=False) policy.eval() if test_fn: test_fn(epoch, global_step)
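
A short, self-contained sketch (with arbitrary numbers) of two of the smaller additions in this
patch, the new `Batch` helpers and the `CollectStats` convenience constructor:

    import numpy as np

    from tianshou.data import Batch
    from tianshou.data.collector import CollectStats

    b = Batch(obs=np.zeros((2, 3)), act=np.array([0, 1]))
    b.to_dict()           # plain (possibly nested) dict of arrays
    b.to_list_of_dicts()  # one dict per element, relies on the fixed Batch.__iter__

    stats = CollectStats.with_autogenerated_stats(
        returns=np.array([1.0, 2.0]),
        lens=np.array([10, 20]),
        n_collected_episodes=2,
        n_collected_steps=30,
    )
    stats.pprint_asdict()  # returns_stat and lens_stat are filled in automatically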