From 6f46f2d5d7f4dbb9210a5cc59f9ab7eedb4f8733 Mon Sep 17 00:00:00 2001
From: bordeauxred <2robert.mueller@gmail.com>
Date: Thu, 28 Mar 2024 18:02:31 +0100
Subject: [PATCH] Feat/refactor collector (#1063)

Closes: #1058

### API Extensions
- Batch received two new methods: `to_dict` and `to_list_of_dicts`. #1063
- `Collector`s can now be closed, and their reset is more granular. #1063
- Trainers can control whether collectors should be reset prior to training. #1063
- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063

### Internal Improvements
- `Collector`s rely less on state; the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
- Introduced a first iteration of a naming convention for vars in `Collector`s. #1063
- Generally improved readability of Collector code and associated tests (still quite some way to go). #1063
- Improved typing for `exploration_noise` and within Collector. #1063

### Breaking Changes
- Removed the `.data` attribute from `Collector` and its child classes. #1063
- Collectors no longer reset the environment on initialization. Instead, the user may need to call `reset` explicitly or pass `reset_before_collect=True`. #1063
- VectorEnvs now return an array of info-dicts on reset instead of a list. #1063
- Fixed `iter(Batch(...))`, which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063

---------

Co-authored-by: Michael Panchenko
---
 CHANGELOG.md                 |  23 +++
 test/base/test_buffer.py     |   6 +-
 test/base/test_collector.py  | 268 +++++++++++++++++------------------
 test/base/test_env.py        |   2 +-
 test/base/test_env_finite.py |  17 +--
 test/modelbased/test_psrl.py |   5 +-
 6 files changed, 164 insertions(+), 157 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0e0ac65c5..5a37acb17 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,4 +1,27 @@
 # Changelog
+## Release 1.1.0
+
+### API Extensions
+- Batch received two new methods: `to_dict` and `to_list_of_dicts`. #1063
+- `Collector`s can now be closed, and their reset is more granular. #1063
+- Trainers can control whether collectors should be reset prior to training. #1063
+- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063
+
+### Internal Improvements
+- `Collector`s rely less on state; the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
+- Introduced a first iteration of a naming convention for vars in `Collector`s. #1063
+- Generally improved readability of Collector code and associated tests (still quite some way to go). #1063
+- Improved typing for `exploration_noise` and within Collector. #1063
+
+### Breaking Changes
+
+- Removed the `.data` attribute from `Collector` and its child classes. #1063
+- Collectors no longer reset the environment on initialization. Instead, the user may need to call `reset`
+explicitly or pass `reset_before_collect=True`. #1063
+- VectorEnvs now return an array of info-dicts on reset instead of a list. #1063
+- Fixed `iter(Batch(...))`, which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063
+
+
 Started after v1.0.0
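The reset-related breaking change above is easiest to see in code. Below is a minimal sketch of the new collection flow, assuming an already-constructed `policy`; it is an illustration based on the changelog entries, and exact signatures may differ from the final API:

```python
import gymnasium as gym

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv

# `policy` is assumed to be an already-constructed tianshou policy.
train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
collector = Collector(policy, train_envs, VectorReplayBuffer(total_size=100, buffer_num=4))

collector.reset()  # no longer happens implicitly on construction
collector.collect(n_step=16)

# Alternatively, fold the reset into the collect call:
collector.collect(n_episode=4, reset_before_collect=True)

collector.close()  # collectors can now be closed explicitly
```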
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index d17596a75..31265f664 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -28,7 +28,7 @@
 from test.base.env import MoveToRightEnv, MyGoalEnv
 
 
-def test_replaybuffer(size=10, bufsize=20) -> None:
+def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None:
     env = MoveToRightEnv(size)
     buf = ReplayBuffer(bufsize)
     buf.update(buf)
@@ -218,7 +218,7 @@ def test_ignore_obs_next(size: int = 10) -> None:
     assert data.obs_next
 
 
-def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None:
+def test_stack(size: int = 5, bufsize: int = 9, stack_num: int = 4, cached_num: int = 3) -> None:
     env = MoveToRightEnv(size)
     buf = ReplayBuffer(bufsize, stack_num=stack_num)
     buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
@@ -289,7 +289,7 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None:
     buf[bufsize * 2]
 
 
-def test_priortized_replaybuffer(size=32, bufsize=15) -> None:
+def test_priortized_replaybuffer(size: int = 32, bufsize: int = 15) -> None:
     env = MoveToRightEnv(size)
     buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
     buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5)

diff --git a/test/base/test_collector.py b/test/base/test_collector.py
index 2718b10e6..6baa6abf3 100644
--- a/test/base/test_collector.py
+++ b/test/base/test_collector.py
@@ -77,8 +77,8 @@ def forward(
         action_shape = self.action_shape if self.action_shape else len(batch.obs)
         return Batch(act=np.ones(action_shape), state=state)
 
-    def learn(self):
-        pass
+    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TrainingStats:
+        raise NotImplementedError
 
 
 def test_collector() -> None:
@@ -107,7 +107,9 @@ def test_collector() -> None:
     # Making one more step results in obs_next=1
     # The final 0 in the buffer.obs is because the buffer is initialized with zeros and the direct attr access
     assert np.allclose(c_single_env.buffer.obs[:4, 0], [0, 1, 0, 0])
-    assert np.allclose(c_single_env.buffer[:].obs_next[..., 0], [1, 2, 1])
+    obs_next = c_single_env.buffer[:].obs_next[..., 0]
+    assert isinstance(obs_next, np.ndarray)
+    assert np.allclose(obs_next, [1, 2, 1])
     keys = np.zeros(100)
     keys[:3] = 1
     assert np.allclose(c_single_env.buffer.info["key"], keys)
@@ -127,7 +129,9 @@ def test_collector() -> None:
     c_single_env.collect(n_episode=3)
     assert len(c_single_env.buffer) == 8
     assert np.allclose(c_single_env.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
-    assert np.allclose(c_single_env.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
+    obs_next = c_single_env.buffer[:].obs_next[..., 0]
+    assert isinstance(obs_next, np.ndarray)
+    assert np.allclose(obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
     assert np.allclose(c_single_env.buffer.info["key"][:8], 1)
     for e in c_single_env.buffer.info["env"][:8]:
         assert isinstance(e, MoveToRightEnv)
@@ -148,7 +152,9 @@ def test_collector() -> None:
     valid_indices = [0, 1, 25, 26, 50, 51, 75, 76]
     obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1]
     assert np.allclose(c_subproc_venv_4_envs.buffer.obs[:, 0], obs)
-    assert np.allclose(c_subproc_venv_4_envs.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
+    obs_next = c_subproc_venv_4_envs.buffer[:].obs_next[..., 0]
+    assert isinstance(obs_next, np.ndarray)
+    assert np.allclose(obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
     keys = np.zeros(100)
     keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
     assert np.allclose(c_subproc_venv_4_envs.buffer.info["key"], keys)
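The `valid_indices` asserted above follow from how `VectorReplayBuffer` partitions its capacity: `total_size=100` split over `buffer_num=4` environments yields one contiguous 25-slot sub-buffer per environment, so the first two transitions of environment `i` land at offsets `25 * i` and `25 * i + 1`. A quick sanity check of that arithmetic (plain Python, mirroring the test's numbers):

```python
total_size, buffer_num, steps_per_env = 100, 4, 2
subbuffer_size = total_size // buffer_num  # 25 slots per environment

valid_indices = [
    env_id * subbuffer_size + step
    for env_id in range(buffer_num)
    for step in range(steps_per_env)
]
assert valid_indices == [0, 1, 25, 26, 50, 51, 75, 76]  # as asserted in the test
```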
@@ -170,8 +176,10 @@ def test_collector() -> None:
     valid_indices = [2, 3, 27, 52, 53, 77, 78, 79]
     obs[valid_indices] = [0, 1, 2, 2, 3, 2, 3, 4]
     assert np.allclose(c_subproc_venv_4_envs.buffer.obs[:, 0], obs)
-    assert np.allclose(
-        c_subproc_venv_4_envs.buffer[:].obs_next[..., 0],
+    obs_next = c_subproc_venv_4_envs.buffer[:].obs_next[..., 0]
+    assert isinstance(obs_next, np.ndarray)
+    assert np.allclose(
+        obs_next,
         [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5],
     )
     keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
@@ -221,9 +229,12 @@ def test_collector() -> None:
     with pytest.raises(TypeError):
         c_dummy_venv_4_envs.collect()
 
+    def get_env_factory(i: int, t: str) -> Callable[[], NXEnv]:
+        return lambda: NXEnv(i, t)
+
     # test NXEnv
     for obs_type in ["array", "object"]:
-        envs = SubprocVectorEnv([lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]])
+        envs = SubprocVectorEnv([get_env_factory(i=i, t=obs_type) for i in [5, 10, 15, 20]])
         c_suproc_new = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4))
         c_suproc_new.reset()
         c_suproc_new.collect(n_step=6)
@@ -231,7 +242,7 @@ def test_collector() -> None:
 
 
 @pytest.fixture()
-def get_AsyncCollector():
+def async_collector_and_env_lens() -> tuple[AsyncCollector, list[int]]:
     env_lens = [2, 3, 4, 5]
     env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens]
 
@@ -243,34 +254,43 @@ def get_AsyncCollector():
         venv,
         VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4),
     )
-    c1.reset()
-    return c1, env_lens
+    async_collector.reset()
+    return async_collector, env_lens
 
 
 class TestAsyncCollector:
-    def test_collect_without_argument_gives_error(self, get_AsyncCollector):
-        c1, env_lens = get_AsyncCollector
+    def test_collect_without_argument_gives_error(
+        self,
+        async_collector_and_env_lens: tuple[AsyncCollector, list[int]],
+    ) -> None:
+        c1, env_lens = async_collector_and_env_lens
         with pytest.raises(TypeError):
             c1.collect()
 
-    def test_collect_one_episode_async(self, get_AsyncCollector):
-        c1, env_lens = get_AsyncCollector
+    def test_collect_one_episode_async(
+        self,
+        async_collector_and_env_lens: tuple[AsyncCollector, list[int]],
+    ) -> None:
+        c1, env_lens = async_collector_and_env_lens
         result = c1.collect(n_episode=1)
         assert result.n_collected_episodes >= 1
 
     def test_enough_episodes_two_collection_cycles_n_episode_without_reset(
         self,
-        get_AsyncCollector,
-    ):
-        c1, env_lens = get_AsyncCollector
+        async_collector_and_env_lens: tuple[AsyncCollector, list[int]],
+    ) -> None:
+        c1, env_lens = async_collector_and_env_lens
         n_episode = 2
         result_c1 = c1.collect(n_episode=n_episode, reset_before_collect=False)
         assert result_c1.n_collected_episodes >= n_episode
         result_c2 = c1.collect(n_episode=n_episode, reset_before_collect=False)
         assert result_c2.n_collected_episodes >= n_episode
 
-    def test_enough_episodes_two_collection_cycles_n_episode_with_reset(self, get_AsyncCollector):
-        c1, env_lens = get_AsyncCollector
+    def test_enough_episodes_two_collection_cycles_n_episode_with_reset(
+        self,
+        async_collector_and_env_lens: tuple[AsyncCollector, list[int]],
+    ) -> None:
+        c1, env_lens = async_collector_and_env_lens
         n_episode = 2
         result_c1 = c1.collect(n_episode=n_episode, reset_before_collect=True)
         assert result_c1.n_collected_episodes >= n_episode
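A side note on the `lambda x=i: ...` pattern used for `env_fns` in the fixture above (and on why the removed `lambda i=x, t=obs_type: NXEnv(i, t)` needed default arguments at all): Python closures bind loop variables late, so either the default-argument idiom or a factory function like `get_env_factory` is required to capture the current loop value. A minimal illustration of the pitfall being avoided:

```python
# Late binding: every closure sees the final value of i.
fns_late = [lambda: i for i in range(3)]
assert [f() for f in fns_late] == [2, 2, 2]

# Default-argument binding captures i at definition time.
fns_bound = [lambda i=i: i for i in range(3)]
assert [f() for f in fns_bound] == [0, 1, 2]
```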
@@ -279,9 +299,9 @@ def test_enough_episodes_two_collection_cycles_n_episode_with_reset(
 
     def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_n_episode(
         self,
-        get_AsyncCollector,
-    ):
-        c1, env_lens = get_AsyncCollector
+        async_collector_and_env_lens: tuple[AsyncCollector, list[int]],
+    ) -> None:
+        c1, env_lens = async_collector_and_env_lens
         ptr = [0, 0, 0, 0]
         bufsize = 60
         for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
@@ -301,9 +321,9 @@ def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collecti
 
     def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_n_step(
         self,
-        get_AsyncCollector,
-    ):
-        c1, env_lens = get_AsyncCollector
+        async_collector_and_env_lens: tuple[AsyncCollector, list[int]],
+    ) -> None:
+        c1, env_lens = async_collector_and_env_lens
         bufsize = 60
         ptr = [0, 0, 0, 0]
         for n_step in tqdm.trange(1, 15, desc="test async n_step"):
@@ -320,17 +340,15 @@ def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collecti
             assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
             assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
 
-    @pytest.mark.parametrize("gym_reset_kwargs", [None, {}])
     def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_first_n_episode_then_n_step(
         self,
-        get_AsyncCollector,
-        gym_reset_kwargs,
-    ):
-        c1, env_lens = get_AsyncCollector
+        async_collector_and_env_lens: tuple[AsyncCollector, list[int]],
+    ) -> None:
+        c1, env_lens = async_collector_and_env_lens
         bufsize = 60
         ptr = [0, 0, 0, 0]
         for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
-            result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs)
+            result = c1.collect(n_episode=n_episode)
             assert result.n_collected_episodes >= n_episode
             # check buffer data, obs and obs_next, env_id
             for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]):
@@ -345,7 +363,7 @@ def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collecti
             assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
         # test async n_step, for now the buffer should be full of data, thus no bincount stuff as above
         for n_step in tqdm.trange(1, 15, desc="test async n_step"):
-            result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs)
+            result = c1.collect(n_step=n_step)
             assert result.n_collected_steps >= n_step
             for i in range(4):
                 env_len = i + 2
@@ -531,96 +549,100 @@ def test_collector_with_multi_agent() -> None:
     c_single_env.buffer.update(c_multi_env_ma.buffer)
     assert len(c_single_env.buffer) in [42, 43]
     if len(c_single_env.buffer) == 42:
-        multi_env_returns = [
-            0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1,
-            0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
-        ]
+        multi_env_returns = np.array(
+            [
+                0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1,
+                0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
+            ],
+        )
     else:
-        multi_env_returns = [
-            0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1,
-            0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
-        ]
+        multi_env_returns = np.array(
+            [
+                0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1,
+                0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
+            ],
+        )
     assert np.all(c_single_env.buffer[:].rew == [[x] * 4 for x in multi_env_returns])
     assert np.all(c_single_env.buffer[:].done == multi_env_returns)
     c2 = Collector(
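In the multi-agent assertions above, each stored reward is a length-4 vector (one entry per agent in this test's setup) while `done` stays scalar, which is why the expected rewards are built as `[[x] * 4 for x in multi_env_returns]`. A small standalone check of that shape logic (toy values, not the test's actual return sequence):

```python
import numpy as np

returns = np.array([0, 1, 0, 1])            # per-step shared return
rew = np.array([[x] * 4 for x in returns])  # replicated across the 4 agents
assert rew.shape == (4, 4)
assert np.array_equal(rew, np.tile(returns[:, None], (1, 4)))
```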
@@ -902,7 +924,7 @@ def test_collector_envpool_gym_reset_return_info() -> None:
     assert np.allclose(c0.buffer.info["env_id"], env_ids)
 
 
-def test_collector_with_vector_env():
+def test_collector_with_vector_env() -> None:
     env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [1, 8, 9, 10]]
 
     dum = DummyVectorEnv(env_fns)
@@ -926,7 +948,7 @@ def test_collector_with_vector_env():
     assert np.array_equal(np.array([1, 1, 1, 8, 1, 9, 1, 10]), c4r.lens)
 
 
-def test_async_collector_with_vector_env():
+def test_async_collector_with_vector_env() -> None:
     env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [1, 8, 9, 10]]
 
     dum = DummyVectorEnv(env_fns)

diff --git a/test/base/test_env.py b/test/base/test_env.py
index d6f6ef04c..a476ec5a9 100644
--- a/test/base/test_env.py
+++ b/test/base/test_env.py
@@ -247,7 +247,7 @@ def create_env(i: int, t: str) -> Callable[[], NXEnv]:
     assert obs.dtype == object
 
 
-def test_env_reset_optional_kwargs(size=10000, num=8) -> None:
+def test_env_reset_optional_kwargs(size: int = 10000, num: int = 8) -> None:
     env_fns = [lambda i=i: MoveToRightEnv(size=i) for i in range(size, size + num)]
     test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv]
     if has_ray():

diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py
index 9142e6a6f..657100554 100644
--- a/test/base/test_env_finite.py
+++ b/test/base/test_env_finite.py
@@ -102,20 +102,24 @@ def _get_default_info(self) -> dict | None:
     # END
 
-    def reset(self, env_id=None):
+    def reset(
+        self,
+        env_id: int | list[int] | np.ndarray | None = None,
+        **kwargs: Any,
+    ) -> tuple[np.ndarray, np.ndarray]:
         env_id = self._wrap_id(env_id)
         self._reset_alive_envs()
 
         # ask super to reset alive envs and remap to current index
         request_id = list(filter(lambda i: i in self._alive_env_ids, env_id))
-        obs = [None] * len(env_id)
-        infos = [None] * len(env_id)
+        obs_list: list[np.ndarray | None] = [None] * len(env_id)
+        infos: list[dict | None] = [None] * len(env_id)
         id2idx = {i: k for k, i in enumerate(env_id)}
         if request_id:
             for k, o, info in zip(request_id, *super().reset(request_id), strict=True):
-                obs[id2idx[k]] = o
+                obs_list[id2idx[k]] = o
                 infos[id2idx[k]] = info
-        for i, o in zip(env_id, obs, strict=True):
+        for i, o in zip(env_id, obs_list, strict=True):
             if o is None and i in self._alive_env_ids:
                 self._alive_env_ids.remove(i)
@@ -133,7 +137,10 @@ def reset(
             self.reset()
             raise StopIteration
 
-        return np.stack(obs), np.array(infos)
+        obs_list = cast(list[np.ndarray], obs_list)
+        infos = cast(list[dict], infos)
+
+        return np.stack(obs_list), np.array(infos)
 
     def step(
         self,

diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py
index 2994b11dd..72742b785 100644
--- a/test/modelbased/test_psrl.py
+++ b/test/modelbased/test_psrl.py
@@ -122,8 +122,9 @@ def stop_fn(mean_rewards: float) -> bool:
         # Let's watch its performance!
         policy.eval()
         test_envs.seed(args.seed)
-        result = test_collector.collect(n_episode=args.test_num, render=args.render)
-        print(f"Final reward: {result.rew_mean}, length: {result.len_mean}")
+        test_collector.reset()
+        stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+        stats.pprint_asdict()
     elif env.spec.reward_threshold:
         assert result.best_reward >= env.spec.reward_threshold
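Finally, the `Batch` conversion helpers listed under API Extensions are not exercised in these tests. A sketch of their intended usage, assuming `to_list_of_dicts` splits along the leading dimension (the natural reading of the name; the exact return format is not shown in this patch):

```python
import numpy as np

from tianshou.data import Batch

batch = Batch(obs=np.zeros((2, 3)), act=np.array([0, 1]))
as_dict = batch.to_dict()           # plain (possibly nested) dict view of the batch
as_rows = batch.to_list_of_dicts()  # one dict per entry of the leading dimension
assert len(as_rows) == 2
```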