Skip to content

Commit

Permalink
Feat/refactor collector (#1063)
Browse files Browse the repository at this point in the history
Closes: #1058 

### Api Extensions
- Batch received two new methods: `to_dict` and `to_list_of_dicts`.
#1063
- `Collector`s can now be closed, and their reset is more granular.
#1063
- Trainers can control whether collectors should be reset prior to
training. #1063
- Convenience constructor for `CollectStats` called
`with_autogenerated_stats`. #1063

### Internal Improvements
- `Collector`s rely less on state, the few stateful things are stored
explicitly instead of through a `.data` attribute. #1063
- Introduced a first iteration of a naming convention for vars in
`Collector`s. #1063
- Generally improved readability of Collector code and associated tests
(still quite some way to go). #1063
- Improved typing for `exploration_noise` and within Collector. #1063

### Breaking Changes

- Removed `.data` attribute from `Collector` and its child classes.
#1063
- Collectors no longer reset the environment on initialization. Instead,
the user might have to call `reset`
expicitly or pass `reset_before_collect=True` . #1063
- VectorEnvs now return an array of info-dicts on reset instead of a
list. #1063
- Fixed `iter(Batch(...)` which now behaves the same way as
`Batch(...).__iter__()`. Can be considered a bugfix. #1063

---------

Co-authored-by: Michael Panchenko <[email protected]>
  • Loading branch information
bordeauxred and Michael Panchenko authored Mar 28, 2024
1 parent edae9e4 commit 4f65b13
Show file tree
Hide file tree
Showing 44 changed files with 1,152 additions and 642 deletions.
23 changes: 23 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,27 @@
# Changelog

## Release 1.1.0

### Api Extensions
- Batch received two new methods: `to_dict` and `to_list_of_dicts`. #1063
- `Collector`s can now be closed, and their reset is more granular. #1063
- Trainers can control whether collectors should be reset prior to training. #1063
- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063

### Internal Improvements
- `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
- Introduced a first iteration of a naming convention for vars in `Collector`s. #1063
- Generally improved readability of Collector code and associated tests (still quite some way to go). #1063
- Improved typing for `exploration_noise` and within Collector. #1063

### Breaking Changes

- Removed `.data` attribute from `Collector` and its child classes. #1063
- Collectors no longer reset the environment on initialization. Instead, the user might have to call `reset`
expicitly or pass `reset_before_collect=True` . #1063
- VectorEnvs now return an array of info-dicts on reset instead of a list. #1063
- Fixed `iter(Batch(...)` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063


Started after v1.0.0

2 changes: 1 addition & 1 deletion docs/02_notebooks/L0_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@
"source": [
"# Let's watch its performance!\n",
"policy.eval()\n",
"eval_result = test_collector.collect(n_episode=1, render=False)\n",
"eval_result = test_collector.collect(n_episode=3, render=False)\n",
"print(f\"Final reward: {eval_result.returns.mean()}, length: {eval_result.lens.mean()}\")"
]
},
Expand Down
5 changes: 2 additions & 3 deletions docs/02_notebooks/L5_Collector.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
},
"outputs": [],
"source": [
"collect_result = test_collector.collect(n_episode=9)\n",
"collect_result = test_collector.collect(reset_before_collect=True, n_episode=9)\n",
"\n",
"collect_result.pprint_asdict()"
]
Expand All @@ -146,8 +146,7 @@
"outputs": [],
"source": [
"# Reset the collector\n",
"test_collector.reset()\n",
"collect_result = test_collector.collect(n_episode=9, random=True)\n",
"collect_result = test_collector.collect(reset_before_collect=True, n_episode=9, random=True)\n",
"\n",
"collect_result.pprint_asdict()"
]
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ ignore = [
"RET505",
"D106", # undocumented public nested class
"D205", # blank line after summary (prevents summary-only docstrings, which makes no sense)
"PLW2901", # overwrite vars in loop
]
unfixable = [
"F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all
Expand Down
24 changes: 20 additions & 4 deletions test/base/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,24 @@
from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Space, Tuple


class MyTestEnv(gym.Env):
"""A task for "going right". The task is to go right ``size`` steps."""
class MoveToRightEnv(gym.Env):
"""A task for "going right". The task is to go right ``size`` steps.
The observation is the current index, and the action is to go left or right.
Action 0 is to go left, and action 1 is to go right.
Taking action 0 at index 0 will keep the index at 0.
Arriving at index ``size`` means the task is done.
In the current implementation, stepping after the task is done is possible, which will
lead the index to be larger than ``size``.
Index 0 is the starting point. If reset is called with default options, the index will
be reset to 0.
"""

def __init__(
self,
size: int,
sleep: int = 0,
sleep: float = 0.0,
dict_state: bool = False,
recurse_state: bool = False,
ma_rew: int = 0,
Expand Down Expand Up @@ -74,8 +85,13 @@ def __init__(
def reset(
self,
seed: int | None = None,
# TODO: passing a dict here doesn't make any sense
options: dict[str, Any] | None = None,
) -> tuple[dict[str, Any] | np.ndarray, dict]:
""":param seed:
:param options: the start index is provided in options["state"]
:return:
"""
if options is None:
options = {"state": 0}
super().reset(seed=seed)
Expand Down Expand Up @@ -188,7 +204,7 @@ def step(
return self._encode_obs(), 1.0, False, False, {}


class MyGoalEnv(MyTestEnv):
class MyGoalEnv(MoveToRightEnv):
def __init__(self, *args: Any, **kwargs: Any) -> None:
assert (
kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0
Expand Down
12 changes: 6 additions & 6 deletions test/base/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
from tianshou.data.utils.converter import to_hdf5

if __name__ == "__main__":
from env import MyGoalEnv, MyTestEnv
from env import MoveToRightEnv, MyGoalEnv
else: # pytest
from test.base.env import MyGoalEnv, MyTestEnv
from test.base.env import MoveToRightEnv, MyGoalEnv


def test_replaybuffer(size=10, bufsize=20) -> None:
env = MyTestEnv(size)
env = MoveToRightEnv(size)
buf = ReplayBuffer(bufsize)
buf.update(buf)
assert str(buf) == buf.__class__.__name__ + "()"
Expand Down Expand Up @@ -209,7 +209,7 @@ def test_ignore_obs_next(size=10) -> None:


def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None:
env = MyTestEnv(size)
env = MoveToRightEnv(size)
buf = ReplayBuffer(bufsize, stack_num=stack_num)
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True)
Expand Down Expand Up @@ -280,7 +280,7 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None:


def test_priortized_replaybuffer(size=32, bufsize=15) -> None:
env = MyTestEnv(size)
env = MoveToRightEnv(size)
buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5)
obs, info = env.reset()
Expand Down Expand Up @@ -1028,7 +1028,7 @@ def test_multibuf_stack() -> None:
bufsize = 9
stack_num = 4
cached_num = 3
env = MyTestEnv(size)
env = MoveToRightEnv(size)
# test if CachedReplayBuffer can handle stack_num + ignore_obs_next
buf4 = CachedReplayBuffer(
ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True),
Expand Down
Loading

0 comments on commit 4f65b13

Please sign in to comment.