Skip to content

Commit

Permalink
Merge branch 'master' into dependabot/pip/setuptools-70.0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Trinkle23897 authored Jul 18, 2024
2 parents c150fcc + db8072a commit 2c81cb9
Show file tree
Hide file tree
Showing 10 changed files with 73 additions and 105 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
- policy:
- introduced attribute `in_training_step` that is controlled by the trainer. #1123
- policy automatically set to `eval` mode when collecting and to `train` mode when updating. #1123
- Extended interface of `compute_action` to also support array-like inputs #1169
- `highlevel`:
- `SamplingConfig`:
- Add support for `batch_size=None`. #1077
Expand Down Expand Up @@ -86,6 +87,7 @@ instead of just `nn.Module`. #1032
Can be considered a bugfix. #1063
- The methods `to_numpy` and `to_torch` in are not in-place anymore
(use `to_numpy_` or `to_torch_` instead). #1098, #1117
- The method `Batch.is_empty` has been removed. Instead, the user can simply check for emptiness of Batch by using `len` on dicts. #1144
- Logging:
- `BaseLogger.prepare_dict_for_logging` is now abstract. #1074
- Removed deprecated and unused `BasicLogger` (only affects users who subclassed it). #1074
Expand Down
18 changes: 9 additions & 9 deletions docs/01_tutorials/03_batch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -324,35 +324,35 @@ Still, we can use a tree (in the right) to show the structure of ``Batch`` objec

Reserved keys mean that in the future there will eventually be values attached to them. The values can be scalars, tensors, or even **Batch** objects. Understanding this is critical to understand the behavior of ``Batch`` when dealing with heterogeneous Batches.

The introduction of reserved keys gives rise to the need to check if a key is reserved. Tianshou provides ``Batch.is_empty`` to achieve this.
The introduction of reserved keys gives rise to the need to check if a key is reserved.

.. raw:: html

<details>
<summary>Examples of Batch.is_empty</summary>
<summary>Examples of checking whether Batch is empty</summary>

.. code-block:: python
>>> Batch().is_empty()
>>> len(Batch().get_keys()) == 0
True
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
>>> len(Batch(a=Batch(), b=Batch(c=Batch())).get_keys()) == 0
False
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
>>> len(Batch(a=Batch(), b=Batch(c=Batch()))) == 0
True
>>> Batch(d=1).is_empty()
>>> len(Batch(d=1).get_keys()) == 0
False
>>> Batch(a=np.float64(1.0)).is_empty()
>>> len(Batch(a=np.float64(1.0)).get_keys()) == 0
False
.. raw:: html

</details><br>

The ``Batch.is_empty`` function has an option to decide whether to identify direct emptiness (just a ``Batch()``) or to identify recursive emptiness (a ``Batch`` object without any scalar/tensor leaf nodes).
To check whether a Batch is empty, simply use ``len(Batch.get_keys()) == 0`` to decide whether to identify direct emptiness (just a ``Batch()``) or ``len(Batch) == 0`` to identify recursive emptiness (a ``Batch`` object without any scalar/tensor leaf nodes).

.. note::

Do not get confused with ``Batch.is_empty`` and ``Batch.empty``. ``Batch.empty`` and its in-place variant ``Batch.empty_`` are used to set some values to zeros or None. Check the API documentation for further details.
Do not get confused with ``Batch.empty``. ``Batch.empty`` and its in-place variant ``Batch.empty_`` are used to set some values to zeros or None. Check the API documentation for further details.


Length and Shape
Expand Down
21 changes: 13 additions & 8 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 18 additions & 18 deletions test/base/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,23 @@

def test_batch() -> None:
assert list(Batch()) == []
assert Batch().is_empty()
assert not Batch(b={"c": {}}).is_empty()
assert Batch(b={"c": {}}).is_empty(recurse=True)
assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
assert not Batch(d=1).is_empty()
assert not Batch(a=np.float64(1.0)).is_empty()
assert len(Batch().get_keys()) == 0
assert len(Batch(b={"c": {}}).get_keys()) != 0
assert len(Batch(b={"c": {}})) == 0
assert len(Batch(a=Batch(), b=Batch(c=Batch())).get_keys()) != 0
assert len(Batch(a=Batch(), b=Batch(c=Batch()))) == 0
assert len(Batch(d=1).get_keys()) != 0
assert len(Batch(a=np.float64(1.0)).get_keys()) != 0
assert len(Batch(a=[1, 2, 3], b={"c": {}})) == 3
assert not Batch(a=[1, 2, 3]).is_empty()
assert len(Batch(a=[1, 2, 3]).get_keys()) != 0
b = Batch({"a": [4, 4], "b": [5, 5]}, c=[None, None])
assert b.c.dtype == object
b = Batch(d=[None], e=[starmap], f=Batch)
assert b.d.dtype == b.e.dtype == object
assert b.f == Batch
b = Batch()
b.update()
assert b.is_empty()
assert len(b.get_keys()) == 0
b.update(c=[3, 5])
assert np.allclose(b.c, [3, 5])
# mimic the behavior of dict.update, where kwargs can overwrite keys
Expand Down Expand Up @@ -141,7 +141,7 @@ def test_batch() -> None:
assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
assert batch2_sum.a.d.f.is_empty()
assert len(batch2_sum.a.d.f.get_keys()) == 0
with pytest.raises(TypeError):
batch2 += [1] # type: ignore # error is raised explicitly
batch3 = Batch(a={"c": np.zeros(1), "d": Batch(e=np.array([0.0]), f=np.array([3.0]))})
Expand Down Expand Up @@ -255,7 +255,7 @@ def test_batch_cat_and_stack() -> None:
ans = Batch.cat([a, b, a])
assert np.allclose(ans.a.a, np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
assert ans.a.t.is_empty()
assert len(ans.a.t.get_keys()) == 0

b1.stack_([b2])
assert isinstance(b1.a.d.e, np.ndarray)
Expand Down Expand Up @@ -296,7 +296,7 @@ def test_batch_cat_and_stack() -> None:
b=torch.cat([torch.zeros(3, 3), b2.b]),
common=Batch(c=np.concatenate([b1.common.c, b2.common.c])),
)
assert ans.a.is_empty()
assert len(ans.a.get_keys()) == 0
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)

Expand Down Expand Up @@ -325,7 +325,7 @@ def test_batch_cat_and_stack() -> None:
assert np.allclose(d.d, [0, 6, 9])

# test stack with empty Batch()
assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
assert len(Batch.stack([Batch(), Batch(), Batch()]).get_keys()) == 0
a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
b = Batch(a=4, b=5, d=6, e=Batch())
c = Batch(c=7, b=6, d=9, e=Batch())
Expand All @@ -334,12 +334,12 @@ def test_batch_cat_and_stack() -> None:
assert np.allclose(d.b, [2, 5, 6])
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])
assert d.e.is_empty()
assert len(d.e.get_keys()) == 0
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2], axis=-1)
assert test.a.is_empty()
assert test.b.is_empty()
assert len(test.a.get_keys()) == 0
assert len(test.b.get_keys()) == 0
assert np.allclose(test.common.c, np.stack([b1.common.c, b2.common.c], axis=-1))

b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
Expand All @@ -362,9 +362,9 @@ def test_batch_cat_and_stack() -> None:

# exceptions
batch_cat: Batch = Batch.cat([])
assert batch_cat.is_empty()
assert len(batch_cat.get_keys()) == 0
batch_stack: Batch = Batch.stack([])
assert batch_stack.is_empty()
assert len(batch_stack.get_keys()) == 0
b1 = Batch(e=[4, 5], d=6)
b2 = Batch(e=[4, 6])
with pytest.raises(ValueError):
Expand Down
4 changes: 2 additions & 2 deletions test/base/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,5 +1378,5 @@ def test_custom_key() -> None:
sampled_batch.__dict__[key],
Batch,
):
assert batch.__dict__[key].is_empty()
assert sampled_batch.__dict__[key].is_empty()
assert len(batch.__dict__[key].get_keys()) == 0
assert len(sampled_batch.__dict__[key].get_keys()) == 0
Loading

0 comments on commit 2c81cb9

Please sign in to comment.