Skip to content

Commit

Permalink
Fixed off-by-1 indexing error, added one more test
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Panchenko committed Sep 2, 2024
1 parent 68dec19 commit c903f72
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
8 changes: 4 additions & 4 deletions docs/02_notebooks/L6_Trainer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@
"outputs": [],
"source": [
"train_env_num = 4\n",
"buffer_size = (\n",
" 2000 # Since REINFORCE is an on-policy algorithm, we don't need a very large buffer size\n",
")\n",
"# Since REINFORCE is an on-policy algorithm, we don't need a very large buffer size\n",
"buffer_size = 2000\n",
"\n",
"\n",
"# Create the environments, used for training and evaluation\n",
"env = gym.make(\"CartPole-v1\")\n",
Expand Down Expand Up @@ -275,7 +275,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.11.4"
}
},
"nbformat": 4,
Expand Down
1 change: 1 addition & 0 deletions test/base/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,7 @@ def test_get_replay_buffer_indices(dummy_rollout_batch: RolloutBatchProtocol) ->
buffer.add(dummy_rollout_batch)
assert np.array_equal(buffer.get_buffer_indices(0, 3), [0, 1, 2])
assert np.array_equal(buffer.get_buffer_indices(3, 2), [3, 4, 0, 1])
assert np.array_equal(buffer.get_buffer_indices(0, 5), np.arange(5))


def test_get_vector_replay_buffer_indices(dummy_rollout_batch: RolloutBatchProtocol) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tianshou/data/buffer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def get_buffer_indices(self, start: int, stop: int) -> np.ndarray:
:return: The indices of the transitions in the buffer between start and stop.
"""
start_left_edge = np.searchsorted(self.subbuffer_edges, start, side="right") - 1
stop_left_edge = np.searchsorted(self.subbuffer_edges, stop, side="right") - 1
stop_left_edge = np.searchsorted(self.subbuffer_edges, stop - 1, side="right") - 1
if start_left_edge != stop_left_edge:
raise ValueError(
f"Start and stop indices must be within the same subbuffer. "
Expand Down

0 comments on commit c903f72

Please sign in to comment.