Feat/refactor collector #1063

Merged
merged 80 commits into from
Mar 28, 2024
80 commits
a1e3908
remove self.data, break async tests
bordeauxred Feb 23, 2024
72004a1
Remove preprocess_fn, adapt tests.
bordeauxred Feb 23, 2024
17d7e8a
Removing cur_rollout_batch. WIP
bordeauxred Mar 7, 2024
e9a3278
Renamings of vars in collect
Mar 7, 2024
c6a707e
minor rename
bordeauxred Mar 7, 2024
47bfa8c
Adjusted return type in BasePolicy.forward
Mar 7, 2024
a5b3601
More var renamings in collector
Mar 7, 2024
6363af8
Enhanced comments in collector
Mar 8, 2024
7b37eb1
Formatting
Mar 8, 2024
49d5648
Collector bugfixes: pass info to obs batch, don't use len on an int
Mar 8, 2024
bad9696
Collector bugfix: move breakout of the loop to after updating collect…
Mar 8, 2024
db5b9a3
Collector bugfix: missing () around walrus
Mar 8, 2024
d65d80d
Renaming of env and policy used in test, added docstrings
Mar 8, 2024
4641962
Tests: removed useless repetition of test_collector
Mar 8, 2024
7db9fca
Tests: removed more useless repetitions in test_collector
Mar 8, 2024
0282b81
Trainer: more control over buffer resetting, fixed call to iter
Mar 8, 2024
9add4f8
Fixes in tests - new Collector logic!
Mar 8, 2024
fdede35
Minor, aesthetic
Mar 8, 2024
7b041ad
Collector: fix persistence between collect iterations, remove reset f…
Mar 8, 2024
0fa5148
Fixed collector test by reinstating data mutation AFTER ADDING TO BUF…
Mar 8, 2024
900df81
Aesthetic
Mar 9, 2024
b2c43ec
Collector: fixed bug introduced in refactoring in persistence of _pre…
Mar 11, 2024
571ba45
Rename variable, fix mixup in collectory.collect docstring.
bordeauxred Mar 11, 2024
cf58c7c
Rename variables in AsyncCollector
bordeauxred Mar 12, 2024
32e8a34
Rename id to env_ids in venv; reformat
bordeauxred Mar 12, 2024
c84d459
fix typo and unnecessary mypy type ignore
bordeauxred Mar 13, 2024
c4a0bec
extract compute_action_policy_hidden in AsyncCollector
bordeauxred Mar 13, 2024
b3b3ae6
start replacing cur_data_batch, whole_batch
bordeauxred Mar 13, 2024
6c3614d
refactor async.collect to replace cur_rollout_batch, pull reset hidde…
bordeauxred Mar 14, 2024
4da2e13
add test to collect one episode in asyncCollector
bordeauxred Mar 14, 2024
877fbab
Make AsynCollector.collect work
bordeauxred Mar 15, 2024
5138321
Add option to reset_before_collect to AsycColl in n_episode case
bordeauxred Mar 15, 2024
6868ebf
Rewrite AsyncCollector tests
bordeauxred Mar 15, 2024
eb8119b
Rewrite AsyncCollector to not use self.current_transition_in_all_batc…
bordeauxred Mar 15, 2024
b7b9205
remove depreciated test.
bordeauxred Mar 15, 2024
0cda877
ractor collect of asyncCollector, replace self.data by explicit attri…
bordeauxred Mar 19, 2024
1fb64b5
factor out _compute_action_policy shared across Collector and AsyncCo…
bordeauxred Mar 19, 2024
2d5087d
formatting with poe
bordeauxred Mar 19, 2024
d879125
fix some pypy errors
bordeauxred Mar 20, 2024
f560164
adapt type annotation of reset to match step in venv
bordeauxred Mar 21, 2024
f28a6fe
fix some mypy issues
bordeauxred Mar 21, 2024
cbe9140
adapt types of reset in venv_wrappers to match step
bordeauxred Mar 21, 2024
59ecee5
fixing some more mypy issues
bordeauxred Mar 21, 2024
dc6c2a1
Fixed typing things by casting
Mar 21, 2024
ceedd19
fix poe issues, allow to collect fewer episodes than env_num
bordeauxred Mar 22, 2024
b53f18e
factor out get_values_at_indices_if_not_None
bordeauxred Mar 22, 2024
a31cdca
black
bordeauxred Mar 22, 2024
69f8b2b
adapt notebooks based upon doc-build
bordeauxred Mar 22, 2024
d78e920
fix inconsistent indent in docstring
bordeauxred Mar 22, 2024
a91b262
Typo
Mar 22, 2024
52d8a6d
Merge branch 'thuml-master' into feat/refactor_collector
Mar 22, 2024
aad2927
update tests to follow new reset interface
bordeauxred Mar 22, 2024
064088d
Docstring
Mar 22, 2024
d6cd083
Minor fixes in FiniteVectorEnv (created in tests)
Mar 22, 2024
4ebdd22
Make slicing function private
Mar 22, 2024
fb65b10
Added a function reproducing the hacky and wrong behavior of Batch fo…
Mar 22, 2024
da1f78a
Minor, comment
Mar 22, 2024
96824f0
fix multiagent test
bordeauxred Mar 22, 2024
10ac920
Merge remote-tracking branch 'aai-tian/feat/refactor_collector' into …
bordeauxred Mar 22, 2024
1fa6f9a
Fixed type in test
Mar 22, 2024
a94fa71
Fixes in high-level interfaces: reset collectors before collecting
Mar 22, 2024
3b90faf
Fixes in test: add explicit reset of collectors
Mar 22, 2024
729c60b
AsyncCollector: Stricter types and enforcement
Mar 22, 2024
48bd641
Fixes in tests, reset
bordeauxred Mar 25, 2024
39bc3cc
poe format
bordeauxred Mar 25, 2024
be8d7ef
Fixes in tests, int instead of np.int in Discrete.n
bordeauxred Mar 25, 2024
574c6e9
Use CollectStats with autogenerate_stats in AsyncCollector, put len t…
bordeauxred Mar 25, 2024
deb6f78
Added a method to converge a batch into dict or a list of dicts
Mar 25, 2024
abcb44c
Collector: fixed reset behavor for envpool envs
Mar 25, 2024
572dd68
Added explicit iter to batch, fixed mypy issues
Mar 25, 2024
8757d65
Simplified nullable_slice (no need for @overload)
Mar 25, 2024
09cb1c6
Simplified nullable_slice (no need for @overload)
Mar 25, 2024
bff8e7d
Fixed iter for empty batch
Mar 26, 2024
92a7467
Renamed env_ids back to env_id, added a TODO
Mar 26, 2024
0cde27b
Fixed iter in batch
Mar 26, 2024
3f305ab
Fixed test vectorenv
Mar 26, 2024
5470a6c
Protocols: removed accidental method implementation (copy-paste error)
Mar 26, 2024
c04df12
Removed runtime checkable of ObsBatchProtocol
Mar 26, 2024
7e77d08
Changelog [skip ci]
Mar 26, 2024
4f47f23
Refactor: replace overload in exploration_noise by generic
bordeauxred Mar 26, 2024
23 changes: 23 additions & 0 deletions CHANGELOG.md
@@ -1,4 +1,27 @@
 # Changelog

+## Release 1.1.0
+
+### API Extensions
+- Batch received two new methods: `to_dict` and `to_list_of_dicts`. #1063
+- `Collector`s can now be closed, and their reset is more granular. #1063
+- Trainers can control whether collectors should be reset prior to training. #1063
+- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063
+
+### Internal Improvements
+- `Collector`s rely less on state; the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
+- Introduced a first iteration of a naming convention for vars in `Collector`s. #1063
+- Generally improved readability of Collector code and associated tests (still quite some way to go). #1063
+- Improved typing for `exploration_noise` and within Collector. #1063
+
+### Breaking Changes
+
+- Removed `.data` attribute from `Collector` and its child classes. #1063
+- Collectors no longer reset the environment on initialization. Instead, the user might have to call `reset`
+explicitly or pass `reset_before_collect=True`. #1063
+- VectorEnvs now return an array of info-dicts on reset instead of a list. #1063
+- Fixed `iter(Batch(...))` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063
+
+
 Started after v1.0.0
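
Reviewer note on the breaking change about collector resets: the sketch below is a toy illustration of the described contract, not Tianshou's implementation. `ToyEnv` and `ToyCollector` are hypothetical names, and raising `RuntimeError` when collecting without a prior reset is an assumption of this sketch rather than documented library behavior. Only the `reset` method and the `reset_before_collect` flag come from the changelog text.

```python
class ToyEnv:
    """Hypothetical stand-in environment whose state counts up from 0."""

    def __init__(self) -> None:
        self.state: int | None = None

    def reset(self) -> int:
        self.state = 0
        return self.state

    def step(self) -> int:
        assert self.state is not None, "environment was never reset"
        self.state += 1
        return self.state


class ToyCollector:
    """Illustrative only: mimics 'no reset on init, explicit reset required'."""

    def __init__(self, env: ToyEnv) -> None:
        self.env = env
        # Deliberately no env.reset() here, mirroring the changelog entry.
        self._last_obs: int | None = None

    def reset(self) -> None:
        self._last_obs = self.env.reset()

    def collect(self, n_step: int, reset_before_collect: bool = False) -> list[int]:
        if reset_before_collect:
            self.reset()
        if self._last_obs is None:
            # Assumption of this sketch: fail loudly instead of resetting silently.
            raise RuntimeError("call reset() first or pass reset_before_collect=True")
        observations = []
        for _ in range(n_step):
            self._last_obs = self.env.step()
            observations.append(self._last_obs)
        return observations


collector = ToyCollector(ToyEnv())
first = collector.collect(2, reset_before_collect=True)  # starts from a fresh env
second = collector.collect(2)  # continues where the previous call stopped
```

The second call continues from the previous observation, which is the state persistence between `collect` calls that several commits in this PR address.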

2 changes: 1 addition & 1 deletion docs/02_notebooks/L0_overview.ipynb
@@ -164,7 +164,7 @@
 "source": [
 "# Let's watch its performance!\n",
 "policy.eval()\n",
-"eval_result = test_collector.collect(n_episode=1, render=False)\n",
+"eval_result = test_collector.collect(n_episode=3, render=False)\n",
 "print(f\"Final reward: {eval_result.returns.mean()}, length: {eval_result.lens.mean()}\")"
 ]
 },
5 changes: 2 additions & 3 deletions docs/02_notebooks/L5_Collector.ipynb
@@ -119,7 +119,7 @@
 },
 "outputs": [],
 "source": [
-"collect_result = test_collector.collect(n_episode=9)\n",
+"collect_result = test_collector.collect(reset_before_collect=True, n_episode=9)\n",
 "\n",
 "collect_result.pprint_asdict()"
 ]
@@ -146,8 +146,7 @@
 "outputs": [],
 "source": [
 "# Reset the collector\n",
-"test_collector.reset()\n",
-"collect_result = test_collector.collect(n_episode=9, random=True)\n",
+"collect_result = test_collector.collect(reset_before_collect=True, n_episode=9, random=True)\n",
 "\n",
 "collect_result.pprint_asdict()"
 ]
1 change: 1 addition & 0 deletions pyproject.toml
@@ -166,6 +166,7 @@ ignore = [
     "RET505",
     "D106", # undocumented public nested class
     "D205", # blank line after summary (prevents summary-only docstrings, which makes no sense)
+    "PLW2901", # overwrite vars in loop
 ]
 unfixable = [
     "F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all
24 changes: 20 additions & 4 deletions test/base/env.py
@@ -9,13 +9,24 @@
 from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Space, Tuple


-class MyTestEnv(gym.Env):
-    """A task for "going right". The task is to go right ``size`` steps."""
+class MoveToRightEnv(gym.Env):
+    """A task for "going right". The task is to go right ``size`` steps.
+
+    The observation is the current index, and the action is to go left or right.
+    Action 0 is to go left, and action 1 is to go right.
+    Taking action 0 at index 0 will keep the index at 0.
+    Arriving at index ``size`` means the task is done.
+    In the current implementation, stepping after the task is done is possible, which will
+    lead the index to be larger than ``size``.
+
+    Index 0 is the starting point. If reset is called with default options, the index will
+    be reset to 0.
+    """

     def __init__(
         self,
         size: int,
-        sleep: int = 0,
+        sleep: float = 0.0,
         dict_state: bool = False,
         recurse_state: bool = False,
         ma_rew: int = 0,
@@ -74,8 +85,13 @@ def __init__(
     def reset(
         self,
         seed: int | None = None,
+        # TODO: passing a dict here doesn't make any sense
         options: dict[str, Any] | None = None,
     ) -> tuple[dict[str, Any] | np.ndarray, dict]:
+        """:param seed:
+        :param options: the start index is provided in options["state"]
+        :return:
+        """
         if options is None:
             options = {"state": 0}
         super().reset(seed=seed)
@@ -188,7 +204,7 @@ def step(
         return self._encode_obs(), 1.0, False, False, {}


-class MyGoalEnv(MyTestEnv):
+class MyGoalEnv(MoveToRightEnv):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         assert (
             kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0
12 changes: 6 additions & 6 deletions test/base/test_buffer.py
@@ -22,13 +22,13 @@
 from tianshou.data.utils.converter import to_hdf5

 if __name__ == "__main__":
-    from env import MyGoalEnv, MyTestEnv
+    from env import MoveToRightEnv, MyGoalEnv
 else:  # pytest
-    from test.base.env import MyGoalEnv, MyTestEnv
+    from test.base.env import MoveToRightEnv, MyGoalEnv


 def test_replaybuffer(size=10, bufsize=20) -> None:
-    env = MyTestEnv(size)
+    env = MoveToRightEnv(size)
     buf = ReplayBuffer(bufsize)
     buf.update(buf)
     assert str(buf) == buf.__class__.__name__ + "()"
@@ -209,7 +209,7 @@ def test_ignore_obs_next(size=10) -> None:


 def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None:
-    env = MyTestEnv(size)
+    env = MoveToRightEnv(size)
     buf = ReplayBuffer(bufsize, stack_num=stack_num)
     buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
     buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True)
@@ -280,7 +280,7 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None:


 def test_priortized_replaybuffer(size=32, bufsize=15) -> None:
-    env = MyTestEnv(size)
+    env = MoveToRightEnv(size)
     buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
     buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5)
     obs, info = env.reset()
@@ -1028,7 +1028,7 @@ def test_multibuf_stack() -> None:
     bufsize = 9
     stack_num = 4
     cached_num = 3
-    env = MyTestEnv(size)
+    env = MoveToRightEnv(size)
     # test if CachedReplayBuffer can handle stack_num + ignore_obs_next
     buf4 = CachedReplayBuffer(
         ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True),