diff --git a/docs/01_tutorials/03_batch.rst b/docs/01_tutorials/03_batch.rst
index 08a335948..c19d884cf 100644
--- a/docs/01_tutorials/03_batch.rst
+++ b/docs/01_tutorials/03_batch.rst
@@ -324,35 +324,35 @@ Still, we can use a tree (in the right) to show the structure of ``Batch`` objec
Reserved keys mean that in the future there will eventually be values attached to them. The values can be scalars, tensors, or even **Batch** objects. Understanding this is critical to understand the behavior of ``Batch`` when dealing with heterogeneous Batches.
-The introduction of reserved keys gives rise to the need to check if a key is reserved. Tianshou provides ``Batch.is_empty`` to achieve this.
+The introduction of reserved keys gives rise to the need to check if a key is reserved.
.. raw:: html
- Examples of Batch.is_empty
+ Examples of checking whether Batch is empty
.. code-block:: python
- >>> Batch().is_empty()
+ >>> len(Batch().get_keys()) == 0
True
- >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
+ >>> len(Batch(a=Batch(), b=Batch(c=Batch())).get_keys()) == 0
False
- >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
+ >>> len(Batch(a=Batch(), b=Batch(c=Batch()))) == 0
True
- >>> Batch(d=1).is_empty()
+ >>> len(Batch(d=1).get_keys()) == 0
False
- >>> Batch(a=np.float64(1.0)).is_empty()
+ >>> len(Batch(a=np.float64(1.0)).get_keys()) == 0
False
.. raw:: html
-The ``Batch.is_empty`` function has an option to decide whether to identify direct emptiness (just a ``Batch()``) or to identify recursive emptiness (a ``Batch`` object without any scalar/tensor leaf nodes).
+To check whether a Batch is empty, simply use `len(Batch.get_keys()) == 0` to decide whether to identify direct emptiness (just a ``Batch()``) or `len(Batch) == 0` to identify recursive emptiness (a ``Batch`` object without any scalar/tensor leaf nodes).
.. note::
- Do not get confused with ``Batch.is_empty`` and ``Batch.empty``. ``Batch.empty`` and its in-place variant ``Batch.empty_`` are used to set some values to zeros or None. Check the API documentation for further details.
+ Do not get confused with ``Batch.empty``. ``Batch.empty`` and its in-place variant ``Batch.empty_`` are used to set some values to zeros or None. Check the API documentation for further details.
Length and Shape
diff --git a/test/base/test_batch.py b/test/base/test_batch.py
index 86d4af500..0530d8232 100644
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -15,15 +15,15 @@
def test_batch() -> None:
assert list(Batch()) == []
- assert Batch().is_empty()
- assert not Batch(b={"c": {}}).is_empty()
- assert Batch(b={"c": {}}).is_empty(recurse=True)
- assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
- assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
- assert not Batch(d=1).is_empty()
- assert not Batch(a=np.float64(1.0)).is_empty()
+ assert len(Batch().get_keys()) == 0
+ assert len(Batch(b={"c": {}}).get_keys()) != 0
+ assert len(Batch(b={"c": {}})) == 0
+ assert len(Batch(a=Batch(), b=Batch(c=Batch())).get_keys()) != 0
+ assert len(Batch(a=Batch(), b=Batch(c=Batch()))) == 0
+ assert len(Batch(d=1).get_keys()) != 0
+ assert len(Batch(a=np.float64(1.0)).get_keys()) != 0
assert len(Batch(a=[1, 2, 3], b={"c": {}})) == 3
- assert not Batch(a=[1, 2, 3]).is_empty()
+ assert len(Batch(a=[1, 2, 3]).get_keys()) != 0
b = Batch({"a": [4, 4], "b": [5, 5]}, c=[None, None])
assert b.c.dtype == object
b = Batch(d=[None], e=[starmap], f=Batch)
@@ -31,7 +31,7 @@ def test_batch() -> None:
assert b.f == Batch
b = Batch()
b.update()
- assert b.is_empty()
+ assert len(b.get_keys()) == 0
b.update(c=[3, 5])
assert np.allclose(b.c, [3, 5])
# mimic the behavior of dict.update, where kwargs can overwrite keys
@@ -141,7 +141,7 @@ def test_batch() -> None:
assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
- assert batch2_sum.a.d.f.is_empty()
+ assert len(batch2_sum.a.d.f.get_keys()) == 0
with pytest.raises(TypeError):
batch2 += [1] # type: ignore # error is raised explicitly
batch3 = Batch(a={"c": np.zeros(1), "d": Batch(e=np.array([0.0]), f=np.array([3.0]))})
@@ -255,7 +255,7 @@ def test_batch_cat_and_stack() -> None:
ans = Batch.cat([a, b, a])
assert np.allclose(ans.a.a, np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
- assert ans.a.t.is_empty()
+ assert len(ans.a.t.get_keys()) == 0
b1.stack_([b2])
assert isinstance(b1.a.d.e, np.ndarray)
@@ -296,7 +296,7 @@ def test_batch_cat_and_stack() -> None:
b=torch.cat([torch.zeros(3, 3), b2.b]),
common=Batch(c=np.concatenate([b1.common.c, b2.common.c])),
)
- assert ans.a.is_empty()
+ assert len(ans.a.get_keys()) == 0
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
@@ -325,7 +325,7 @@ def test_batch_cat_and_stack() -> None:
assert np.allclose(d.d, [0, 6, 9])
# test stack with empty Batch()
- assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
+ assert len(Batch.stack([Batch(), Batch(), Batch()]).get_keys()) == 0
a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
b = Batch(a=4, b=5, d=6, e=Batch())
c = Batch(c=7, b=6, d=9, e=Batch())
@@ -334,12 +334,12 @@ def test_batch_cat_and_stack() -> None:
assert np.allclose(d.b, [2, 5, 6])
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])
- assert d.e.is_empty()
+ assert len(d.e.get_keys()) == 0
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2], axis=-1)
- assert test.a.is_empty()
- assert test.b.is_empty()
+ assert len(test.a.get_keys()) == 0
+ assert len(test.b.get_keys()) == 0
assert np.allclose(test.common.c, np.stack([b1.common.c, b2.common.c], axis=-1))
b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
@@ -362,9 +362,9 @@ def test_batch_cat_and_stack() -> None:
# exceptions
batch_cat: Batch = Batch.cat([])
- assert batch_cat.is_empty()
+ assert len(batch_cat.get_keys()) == 0
batch_stack: Batch = Batch.stack([])
- assert batch_stack.is_empty()
+ assert len(batch_stack.get_keys()) == 0
b1 = Batch(e=[4, 5], d=6)
b2 = Batch(e=[4, 6])
with pytest.raises(ValueError):
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index 1b3593db3..9f3f40828 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -1378,5 +1378,5 @@ def test_custom_key() -> None:
sampled_batch.__dict__[key],
Batch,
):
- assert batch.__dict__[key].is_empty()
- assert sampled_batch.__dict__[key].is_empty()
+ assert len(batch.__dict__[key].get_keys()) == 0
+ assert len(sampled_batch.__dict__[key].get_keys()) == 0
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index e4ea4a8b1..04505ace3 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -393,9 +393,6 @@ def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None:
def __len__(self) -> int:
...
- def is_empty(self, recurse: bool = False) -> bool:
- ...
-
def split(
self,
size: int,
@@ -514,7 +511,7 @@ def __getitem__(self, index: str | IndexType) -> Any:
if len(batch_items) > 0:
new_batch = Batch()
for batch_key, obj in batch_items:
- if isinstance(obj, Batch) and obj.is_empty():
+ if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
new_batch.__dict__[batch_key] = Batch()
else:
new_batch.__dict__[batch_key] = obj[index]
@@ -574,13 +571,13 @@ def __iadd__(self, other: Self | Number | np.number) -> Self:
other.__dict__.values(),
strict=True,
): # TODO are keys consistent?
- if isinstance(obj, Batch) and obj.is_empty():
+ if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
continue
self.__dict__[batch_key] += value
return self
if _is_number(other):
for batch_key, obj in self.items():
- if isinstance(obj, Batch) and obj.is_empty():
+ if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
continue
self.__dict__[batch_key] += other
return self
@@ -594,7 +591,7 @@ def __imul__(self, value: Number | np.number) -> Self:
"""Algebraic multiplication with a scalar value in-place."""
assert _is_number(value), "Only multiplication by a number is supported."
for batch_key, obj in self.__dict__.items():
- if isinstance(obj, Batch) and obj.is_empty():
+ if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
continue
self.__dict__[batch_key] *= value
return self
@@ -607,7 +604,7 @@ def __itruediv__(self, value: Number | np.number) -> Self:
"""Algebraic division with a scalar value in-place."""
assert _is_number(value), "Only division by a number is supported."
for batch_key, obj in self.__dict__.items():
- if isinstance(obj, Batch) and obj.is_empty():
+ if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
continue
self.__dict__[batch_key] /= value
return self
@@ -722,7 +719,7 @@ def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
{
batch_key
for batch_key, obj in batch.items()
- if not (isinstance(obj, Batch) and obj.is_empty())
+ if not (isinstance(obj, Batch) and len(obj.get_keys()) == 0)
}
for batch in batches
]
@@ -753,7 +750,7 @@ def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
if key not in batch.__dict__:
continue
value = batch.get(key)
- if isinstance(value, Batch) and value.is_empty():
+ if isinstance(value, Batch) and len(value.get_keys()) == 0:
continue
try:
self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value
@@ -772,7 +769,7 @@ def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None:
batch_list.append(Batch(batch))
elif isinstance(batch, Batch):
# x.is_empty() means that x is Batch() and should be ignored
- if not batch.is_empty():
+ if len(batch.get_keys()) != 0:
batch_list.append(batch)
else:
raise ValueError(f"Cannot concatenate {type(batch)} in Batch.cat_")
@@ -783,16 +780,16 @@ def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None:
# x.is_empty(recurse=True) here means x is a nested empty batch
# like Batch(a=Batch), and we have to treat it as length zero and
# keep it.
- lens = [0 if batch.is_empty(recurse=True) else len(batch) for batch in batches]
+ lens = [0 if len(batch) == 0 else len(batch) for batch in batches]
except TypeError as exception:
raise ValueError(
"Batch.cat_ meets an exception. Maybe because there is any "
f"scalar in {batches} but Batch.cat_ does not support the "
"concatenation of scalar.",
) from exception
- if not self.is_empty():
+ if len(self.get_keys()) != 0:
batches = [self, *list(batches)]
- lens = [0 if self.is_empty(recurse=True) else len(self), *lens]
+ lens = [0 if len(self) == 0 else len(self), *lens]
self.__cat(batches, lens)
@staticmethod
@@ -810,14 +807,14 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
batch_list.append(Batch(batch))
elif isinstance(batch, Batch):
# x.is_empty() means that x is Batch() and should be ignored
- if not batch.is_empty():
+ if len(batch.get_keys()) != 0:
batch_list.append(batch)
else:
raise ValueError(f"Cannot concatenate {type(batch)} in Batch.stack_")
if len(batch_list) == 0:
return
batches = batch_list
- if not self.is_empty():
+ if len(self.get_keys()) != 0:
batches = [self, *batches]
# collect non-empty keys
keys_map = [
@@ -930,7 +927,7 @@ def __len__(self) -> int:
# TODO: causes inconsistent behavior to batch with empty batches
# and batch with empty sequences of other type. Remove, but only after
# Buffer and Collectors have been improved to no longer rely on this
- if isinstance(obj, Batch) and obj.is_empty(recurse=True):
+ if isinstance(obj, Batch) and len(obj) == 0:
continue
if hasattr(obj, "__len__") and (isinstance(obj, Batch) or obj.ndim > 0):
lens.append(len(obj))
@@ -940,45 +937,10 @@ def __len__(self) -> int:
return 0
return min(lens)
- def is_empty(self, recurse: bool = False) -> bool:
- """Test if a Batch is empty.
-
- If ``recurse=True``, it further tests the values of the object; else
- it only tests the existence of any key.
-
- ``b.is_empty(recurse=True)`` is mainly used to distinguish
- ``Batch(a=Batch(a=Batch()))`` and ``Batch(a=1)``. They both raise
- exceptions when applied to ``len()``, but the former can be used in
- ``cat``, while the latter is a scalar and cannot be used in ``cat``.
-
- Another usage is in ``__len__``, where we have to skip checking the
- length of recursively empty Batch.
- ::
-
- >>> Batch().is_empty()
- True
- >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
- False
- >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
- True
- >>> Batch(d=1).is_empty()
- False
- >>> Batch(a=np.float64(1.0)).is_empty()
- False
- """
- if len(self.__dict__) == 0:
- return True
- if not recurse:
- return False
- return all(
- False if not isinstance(obj, Batch) else obj.is_empty(recurse=True)
- for obj in self.values()
- )
-
@property
def shape(self) -> list[int]:
"""Return self.shape."""
- if self.is_empty():
+ if len(self.get_keys()) == 0:
return []
data_shape = []
for obj in self.__dict__.values():
diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py
index b1719be56..2ecddc5ce 100644
--- a/tianshou/data/buffer/base.py
+++ b/tianshou/data/buffer/base.py
@@ -208,7 +208,7 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray:
self._index = (self._index + 1) % self.maxsize
self._size = min(self._size + 1, self.maxsize)
to_indices = np.array(to_indices)
- if self._meta.is_empty():
+ if len(self._meta.get_keys()) == 0:
self._meta = create_value(buffer._meta, self.maxsize, stack=False) # type: ignore
self._meta[to_indices] = buffer._meta[from_indices]
return to_indices
@@ -284,7 +284,7 @@ def add(
batch.done = batch.done.astype(bool)
batch.terminated = batch.terminated.astype(bool)
batch.truncated = batch.truncated.astype(bool)
- if self._meta.is_empty():
+ if len(self._meta.get_keys()) == 0:
self._meta = create_value(batch, self.maxsize, stack) # type: ignore
else: # dynamic key pops up in batch
alloc_by_keys_diff(self._meta, batch, self.maxsize, stack)
@@ -377,7 +377,7 @@ def get(
return np.stack(stack, axis=indices.ndim)
except IndexError as exception:
- if not (isinstance(val, Batch) and val.is_empty()):
+ if not (isinstance(val, Batch) and len(val.get_keys()) == 0):
raise exception # val != Batch()
return Batch()
diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py
index 90480257a..38db7d120 100644
--- a/tianshou/data/buffer/manager.py
+++ b/tianshou/data/buffer/manager.py
@@ -29,7 +29,7 @@ def __init__(self, buffer_list: list[ReplayBuffer] | list[HERReplayBuffer]) -> N
buffer_type = type(self.buffers[0])
kwargs = self.buffers[0].options
for buf in self.buffers:
- assert buf._meta.is_empty()
+ assert len(buf._meta.get_keys()) == 0
assert isinstance(buf, buffer_type)
assert buf.options == kwargs
offset.append(size)
@@ -161,7 +161,7 @@ def add(
batch.done = batch.done.astype(bool)
batch.terminated = batch.terminated.astype(bool)
batch.truncated = batch.truncated.astype(bool)
- if self._meta.is_empty():
+ if len(self._meta.get_keys()) == 0:
self._meta = create_value(batch, self.maxsize, stack=False) # type: ignore
else: # dynamic key pops up in batch
alloc_by_keys_diff(self._meta, batch, self.maxsize, False)
diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py
index 81cfe0a6d..e88214d45 100644
--- a/tianshou/policy/multiagent/mapolicy.py
+++ b/tianshou/policy/multiagent/mapolicy.py
@@ -272,7 +272,7 @@ def learn( # type: ignore
agent_id_to_stats = {}
for agent_id, policy in self.policies.items():
data = batch[agent_id]
- if not data.is_empty():
+ if len(data.get_keys()) != 0:
train_stats = policy.learn(batch=data, **kwargs)
agent_id_to_stats[agent_id] = train_stats
return MapTrainingStats(agent_id_to_stats)