Skip to content

Commit

Permalink
Moved subbuffer-related functionality from Collector to Buffer (#1214)
Browse files Browse the repository at this point in the history
Internal improvements post #1196
  • Loading branch information
MischaPanch authored Sep 2, 2024
1 parent 36dd21a commit 16f2fc2
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 132 deletions.
8 changes: 4 additions & 4 deletions docs/02_notebooks/L6_Trainer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@
"outputs": [],
"source": [
"train_env_num = 4\n",
"buffer_size = (\n",
" 2000 # Since REINFORCE is an on-policy algorithm, we don't need a very large buffer size\n",
")\n",
"# Since REINFORCE is an on-policy algorithm, we don't need a very large buffer size\n",
"buffer_size = 2000\n",
"\n",
"\n",
"# Create the environments, used for training and evaluation\n",
"env = gym.make(\"CartPole-v1\")\n",
Expand Down Expand Up @@ -275,7 +275,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.11.4"
}
},
"nbformat": 4,
Expand Down
2 changes: 2 additions & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,5 @@ subclass
subclassing
dist
dists
subbuffer
subbuffers
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ ignore = [
"PLW2901", # overwrite vars in loop
"B027", # empty and non-abstract method in abstract class
"D404", # It's fine to start with "This" in docstrings
"D407", "D408", "D409", # Ruff rules for underlines under 'Example:' and so clash with Sphinx
]
unfixable = [
"F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all
Expand Down
44 changes: 44 additions & 0 deletions test/base/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1503,3 +1503,47 @@ def test_buffer_dropnull() -> None:
buf.dropnull()
assert len(buf[:3]) == 3
assert not buf.hasnull()


@pytest.fixture
def dummy_rollout_batch() -> RolloutBatchProtocol:
return cast(
RolloutBatchProtocol,
Batch(
obs=np.arange(2),
obs_next=np.arange(2),
act=np.arange(5),
rew=1,
terminated=False,
truncated=False,
done=False,
info={},
),
)


def test_get_replay_buffer_indices(dummy_rollout_batch: RolloutBatchProtocol) -> None:
buffer = ReplayBuffer(5)
for _ in range(5):
buffer.add(dummy_rollout_batch)
assert np.array_equal(buffer.get_buffer_indices(0, 3), [0, 1, 2])
assert np.array_equal(buffer.get_buffer_indices(3, 2), [3, 4, 0, 1])
assert np.array_equal(buffer.get_buffer_indices(0, 5), np.arange(5))


def test_get_vector_replay_buffer_indices(dummy_rollout_batch: RolloutBatchProtocol) -> None:
stacked_batch = Batch.stack([dummy_rollout_batch, dummy_rollout_batch])
buffer = VectorReplayBuffer(10, 2)
for _ in range(5):
buffer.add(stacked_batch)

assert np.array_equal(buffer.get_buffer_indices(0, 3), [0, 1, 2])
assert np.array_equal(buffer.get_buffer_indices(3, 2), [3, 4, 0, 1])

assert np.array_equal(buffer.get_buffer_indices(6, 9), [6, 7, 8])
assert np.array_equal(buffer.get_buffer_indices(8, 7), [8, 9, 5, 6])

with pytest.raises(ValueError):
buffer.get_buffer_indices(3, 6)
with pytest.raises(ValueError):
buffer.get_buffer_indices(6, 3)
170 changes: 155 additions & 15 deletions tianshou/data/buffer/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Self, TypeVar, cast
from collections.abc import Sequence
from typing import Any, ClassVar, Self, TypeVar, cast

import h5py
import numpy as np
Expand Down Expand Up @@ -60,6 +61,14 @@ class ReplayBuffer:
"info",
"policy",
)
_required_keys_for_add: ClassVar[set[str]] = {
"obs",
"act",
"rew",
"terminated",
"truncated",
"done",
}

def __init__(
self,
Expand Down Expand Up @@ -103,6 +112,111 @@ def subbuffer_edges(self) -> np.ndarray:
"""
return np.array([0, self.maxsize], dtype=int)

def _get_start_stop_tuples_for_edge_crossing_interval(
self,
start: int,
stop: int,
) -> tuple[tuple[int, int], tuple[int, int]]:
"""Assumes that stop < start and retrieves tuples corresponding to the two
slices that determine the interval within the buffer.
Example:
-------
>>> list(self.subbuffer_edges) == [0, 5, 10]
>>> start = 4
>>> stop = 2
>>> self._get_start_stop_tuples_for_edge_crossing_interval(start, stop)
((4, 5), (0, 2))
The buffer sliced from 4 to 5 and then from 0 to 2 will contain the transitions
corresponding to the provided start and stop values.
"""
if stop >= start:
raise ValueError(
f"Expected stop < start, but got {start=}, {stop=}. "
f"For stop larger than start this method should never be called, "
f"and stop=start should never occur. This can occur either due to an implementation error, "
f"or due a bad configuration of the buffer that resulted in a single episode being so long that "
f"it completely filled a subbuffer (of size len(buffer)/degree_of_vectorization). "
f"Consider either shortening the episode, increasing the size of the buffer, or decreasing the "
f"degree of vectorization.",
)
subbuffer_edges = cast(Sequence[int], self.subbuffer_edges)

edge_after_start_idx = int(np.searchsorted(subbuffer_edges, start, side="left"))
"""This is the crossed edge"""

if edge_after_start_idx == 0:
raise ValueError(
f"The start value should be larger than the first edge, but got {start=}, {subbuffer_edges[1]=}.",
)
edge_after_start = subbuffer_edges[edge_after_start_idx]
edge_before_stop = subbuffer_edges[edge_after_start_idx - 1]
"""It's the edge before the crossed edge"""

if edge_before_stop >= stop:
raise ValueError(
f"The edge before the crossed edge should be smaller than the stop, but got {edge_before_stop=}, {stop=}.",
)
return (start, edge_after_start), (edge_before_stop, stop)

def get_buffer_indices(self, start: int, stop: int) -> np.ndarray:
"""Get the indices of the transitions in the buffer between start and stop.
The special thing about this is that stop may actually be smaller than start,
since one often is interested in a sequence of transitions that goes over a subbuffer edge.
The main use case for this method is to retrieve an episode from the buffer, in which case
start is the index of the first transition in the episode and stop is the index where `done` is True + 1.
This can be done with the following code:
.. code-block:: python
episode_indices = buffer.get_buffer_indices(episode_start_index, episode_done_index + 1)
episode = buffer[episode_indices]
Even when `start` is smaller than `stop`, it will be validated that they are in the same subbuffer.
Example:
--------
>>> list(buffer.subbuffer_edges) == [0, 5, 10]
>>> buffer.get_buffer_indices(start=2, stop=4)
[2, 3]
>>> buffer.get_buffer_indices(start=4, stop=2)
[4, 0, 1]
>>> buffer.get_buffer_indices(start=8, stop=7)
[8, 9, 5, 6]
>>> buffer.get_buffer_indices(start=1, stop=6)
ValueError: Start and stop indices must be within the same subbuffer.
>>> buffer.get_buffer_indices(start=8, stop=1)
ValueError: Start and stop indices must be within the same subbuffer.
:param start: The start index of the interval.
:param stop: The stop index of the interval.
:return: The indices of the transitions in the buffer between start and stop.
"""
start_left_edge = np.searchsorted(self.subbuffer_edges, start, side="right") - 1
stop_left_edge = np.searchsorted(self.subbuffer_edges, stop - 1, side="right") - 1
if start_left_edge != stop_left_edge:
raise ValueError(
f"Start and stop indices must be within the same subbuffer. "
f"Got {start=} in subbuffer edge {start_left_edge} and {stop=} in subbuffer edge {stop_left_edge}.",
)
if stop > start:
return np.arange(start, stop, dtype=int)
else:
(start, upper_edge), (
lower_edge,
stop,
) = self._get_start_stop_tuples_for_edge_crossing_interval(
start,
stop,
)
log.debug(f"{start=}, {upper_edge=}, {lower_edge=}, {stop=}")
return np.concatenate(
(np.arange(start, upper_edge, dtype=int), np.arange(lower_edge, stop, dtype=int)),
)

def __len__(self) -> int:
return self._size

Expand Down Expand Up @@ -297,43 +411,69 @@ def add(
:param batch: the input data batch. "obs", "act", "rew",
"terminated", "truncated" are required keys.
:param buffer_ids: to make consistent with other buffer's add function; if it
is not None, we assume the input batch's first dimension is always 1.
:param buffer_ids: id's of subbuffers, allowed here to be consistent with classes similar to
:class:`~tianshou.data.buffer.vecbuf.VectorReplayBuffer`. Since the `ReplayBuffer`
has a single subbuffer, if this is not None, it must be a single element with value 0.
In that case, the batch is expected to have the shape (1, len(data)).
Failure to adhere to this will result in a `ValueError`.
Return (current_index, episode_return, episode_length, episode_start_index). If
Return `(current_index, episode_return, episode_length, episode_start_index)`. If
the episode is not finished, the return value of episode_length and
episode_reward is 0.
"""
# preprocess batch
# preprocess and copy batch into a new Batch object to avoid mutating the input
# TODO: can't we just copy? Why do we need to rely on setting inside __dict__?
new_batch = Batch()
for key in batch.get_keys():
new_batch.__dict__[key] = batch[key]
batch = new_batch
batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(

# has to be done after preprocess batch
if not self._required_keys_for_add.issubset(
batch.get_keys(),
) # important to do after preprocess batch
stacked_batch = buffer_ids is not None
if stacked_batch:
assert len(batch) == 1
):
raise ValueError(
f"Input batch must have the following keys: {self._required_keys_for_add}",
)

batch_is_stacked = False
"""True when instead of passing a batch of shape (len(data)), a batch of shape (1, len(data)) is passed."""

if buffer_ids is not None:
if len(buffer_ids) != 1 and buffer_ids[0] != 0:
raise ValueError(
"If `buffer_ids` is not None, it must be a single element with value 0 for the non-vectorized `ReplayBuffer`. "
f"Got {buffer_ids=}.",
)
if len(batch) != 1:
raise ValueError(
f"If `buffer_ids` is not None, the batch must have the shape (1, len(data)) but got {len(batch)=}.",
)
batch_is_stacked = True

# block dealing with exotic options that are currently only used for atari, see various TODOs about that
# These options have interactions with the case when buffer_ids is not None
if self._save_only_last_obs:
batch.obs = batch.obs[:, -1] if stacked_batch else batch.obs[-1]
batch.obs = batch.obs[:, -1] if batch_is_stacked else batch.obs[-1]
if not self._save_obs_next:
batch.pop("obs_next", None)
elif self._save_only_last_obs:
batch.obs_next = batch.obs_next[:, -1] if stacked_batch else batch.obs_next[-1]
# get ptr
if stacked_batch:
batch.obs_next = batch.obs_next[:, -1] if batch_is_stacked else batch.obs_next[-1]

if batch_is_stacked:
rew, done = batch.rew[0], batch.done[0]
else:
rew, done = batch.rew, batch.done
insertion_idx, ep_return, ep_len, ep_start_idx = (
np.array([x]) for x in self._update_state_pre_add(rew, done)
)

# TODO: improve this, don'r rely on try-except, instead process the batch if needed
try:
self._meta[insertion_idx] = batch
except ValueError:
stack = not stacked_batch
stack = not batch_is_stacked
batch.rew = batch.rew.astype(float)
batch.done = batch.done.astype(bool)
batch.terminated = batch.terminated.astype(bool)
Expand Down
Loading

0 comments on commit 16f2fc2

Please sign in to comment.