From 1a4060e8d8089f2e171b0ce2c89fcc3ebec0903a Mon Sep 17 00:00:00 2001
From: daniel <1534513+dantp-ai@users.noreply.github.com>
Date: Fri, 10 May 2024 18:18:26 +0200
Subject: [PATCH] Remove Batch.is_empty

* Keeps the API lean
* The same functionality can be achieved using len() on the batch's keys
  (via get_keys()) or on the batch itself
---
 docs/01_tutorials/03_batch.rst         | 18 +++----
 test/base/test_batch.py                | 36 +++++++-------
 test/base/test_buffer.py               |  4 +-
 tianshou/data/batch.py                 | 68 ++++++--------------------
 tianshou/data/buffer/base.py           |  6 +--
 tianshou/data/buffer/manager.py        |  4 +-
 tianshou/policy/multiagent/mapolicy.py |  2 +-
 7 files changed, 50 insertions(+), 88 deletions(-)

diff --git a/docs/01_tutorials/03_batch.rst b/docs/01_tutorials/03_batch.rst
index 08a335948..c19d884cf 100644
--- a/docs/01_tutorials/03_batch.rst
+++ b/docs/01_tutorials/03_batch.rst
@@ -324,35 +324,35 @@ Still, we can use a tree (in the right) to show the structure of ``Batch`` objec

 Reserved keys mean that in the future there will eventually be values attached to them. The values can be scalars, tensors, or even **Batch** objects. Understanding this is critical to understand the behavior of ``Batch`` when dealing with heterogeneous Batches.

-The introduction of reserved keys gives rise to the need to check if a key is reserved. Tianshou provides ``Batch.is_empty`` to achieve this.
+The introduction of reserved keys gives rise to the need to check if a key is reserved.

 .. raw:: html
- Examples of Batch.is_empty + Examples of checking whether Batch is empty .. code-block:: python - >>> Batch().is_empty() + >>> len(Batch().get_keys()) == 0 True - >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty() + >>> len(Batch(a=Batch(), b=Batch(c=Batch())).get_keys()) == 0 False - >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True) + >>> len(Batch(a=Batch(), b=Batch(c=Batch()))) == 0 True - >>> Batch(d=1).is_empty() + >>> len(Batch(d=1).get_keys()) == 0 False - >>> Batch(a=np.float64(1.0)).is_empty() + >>> len(Batch(a=np.float64(1.0)).get_keys()) == 0 False .. raw:: html

-The ``Batch.is_empty`` function has an option to decide whether to identify direct emptiness (just a ``Batch()``) or to identify recursive emptiness (a ``Batch`` object without any scalar/tensor leaf nodes).
+To check whether a ``Batch`` is empty, use ``len(batch.get_keys()) == 0`` to test for direct emptiness (just a ``Batch()``), or ``len(batch) == 0`` to test for recursive emptiness (a ``Batch`` object without any scalar/tensor leaf nodes).

 .. note::

-    Do not get confused with ``Batch.is_empty`` and ``Batch.empty``. ``Batch.empty`` and its in-place variant ``Batch.empty_`` are used to set some values to zeros or None. Check the API documentation for further details.
+    Do not confuse these emptiness checks with ``Batch.empty``: ``Batch.empty`` and its in-place variant ``Batch.empty_`` are used to set some values to zeros or None. Check the API documentation for further details.


 Length and Shape
diff --git a/test/base/test_batch.py b/test/base/test_batch.py
index 86d4af500..0530d8232 100644
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -15,15 +15,15 @@ def test_batch() -> None:
     assert list(Batch()) == []
-    assert Batch().is_empty()
-    assert not Batch(b={"c": {}}).is_empty()
-    assert Batch(b={"c": {}}).is_empty(recurse=True)
-    assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
-    assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
-    assert not Batch(d=1).is_empty()
-    assert not Batch(a=np.float64(1.0)).is_empty()
+    assert len(Batch().get_keys()) == 0
+    assert len(Batch(b={"c": {}}).get_keys()) != 0
+    assert len(Batch(b={"c": {}})) == 0
+    assert len(Batch(a=Batch(), b=Batch(c=Batch())).get_keys()) != 0
+    assert len(Batch(a=Batch(), b=Batch(c=Batch()))) == 0
+    assert len(Batch(d=1).get_keys()) != 0
+    assert len(Batch(a=np.float64(1.0)).get_keys()) != 0
     assert len(Batch(a=[1, 2, 3], b={"c": {}})) == 3
-    assert not Batch(a=[1, 2, 3]).is_empty()
+    assert len(Batch(a=[1, 2, 3]).get_keys()) != 0
     b = Batch({"a": [4, 4], "b": [5, 5]}, c=[None, None])
     assert b.c.dtype == object
     b = Batch(d=[None], e=[starmap], f=Batch)
@@ -31,7 +31,7 @@ def test_batch() -> None:
     assert b.f == Batch
     b = Batch()
     b.update()
-    assert b.is_empty()
+    assert len(b.get_keys()) == 0
     b.update(c=[3, 5])
     assert np.allclose(b.c, [3, 5])
     # mimic the behavior of dict.update, where kwargs can overwrite keys
@@ -141,7 +141,7 @@ def test_batch() -> None:
     assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
     assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
     assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
-    assert batch2_sum.a.d.f.is_empty()
+    assert len(batch2_sum.a.d.f.get_keys()) == 0
     with pytest.raises(TypeError):
         batch2 += [1]  # type: ignore  # error is raised explicitly
     batch3 = Batch(a={"c": np.zeros(1), "d": Batch(e=np.array([0.0]), f=np.array([3.0]))})
@@ -255,7 +255,7 @@ def test_batch_cat_and_stack() -> None:
     ans = Batch.cat([a, b, a])
     assert np.allclose(ans.a.a, np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
     assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
-    assert ans.a.t.is_empty()
+    assert len(ans.a.t.get_keys()) == 0

     b1.stack_([b2])
     assert isinstance(b1.a.d.e, np.ndarray)
@@ -296,7 +296,7 @@ def test_batch_cat_and_stack() -> None:
         b=torch.cat([torch.zeros(3, 3), b2.b]),
         common=Batch(c=np.concatenate([b1.common.c, b2.common.c])),
     )
-    assert ans.a.is_empty()
+    assert len(ans.a.get_keys()) == 0
     assert torch.allclose(test.b, ans.b)
     assert np.allclose(test.common.c, ans.common.c)

@@ -325,7 +325,7 @@ def test_batch_cat_and_stack() -> None:
     assert np.allclose(d.d, [0, 6, 9])

     # test
stack with empty Batch() - assert Batch.stack([Batch(), Batch(), Batch()]).is_empty() + assert len(Batch.stack([Batch(), Batch(), Batch()]).get_keys()) == 0 a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch()) b = Batch(a=4, b=5, d=6, e=Batch()) c = Batch(c=7, b=6, d=9, e=Batch()) @@ -334,12 +334,12 @@ def test_batch_cat_and_stack() -> None: assert np.allclose(d.b, [2, 5, 6]) assert np.allclose(d.c, [3, 0, 7]) assert np.allclose(d.d, [0, 6, 9]) - assert d.e.is_empty() + assert len(d.e.get_keys()) == 0 b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5))) b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5))) test = Batch.stack([b1, b2], axis=-1) - assert test.a.is_empty() - assert test.b.is_empty() + assert len(test.a.get_keys()) == 0 + assert len(test.b.get_keys()) == 0 assert np.allclose(test.common.c, np.stack([b1.common.c, b2.common.c], axis=-1)) b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5))) @@ -362,9 +362,9 @@ def test_batch_cat_and_stack() -> None: # exceptions batch_cat: Batch = Batch.cat([]) - assert batch_cat.is_empty() + assert len(batch_cat.get_keys()) == 0 batch_stack: Batch = Batch.stack([]) - assert batch_stack.is_empty() + assert len(batch_stack.get_keys()) == 0 b1 = Batch(e=[4, 5], d=6) b2 = Batch(e=[4, 6]) with pytest.raises(ValueError): diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 1b3593db3..9f3f40828 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -1378,5 +1378,5 @@ def test_custom_key() -> None: sampled_batch.__dict__[key], Batch, ): - assert batch.__dict__[key].is_empty() - assert sampled_batch.__dict__[key].is_empty() + assert len(batch.__dict__[key].get_keys()) == 0 + assert len(sampled_batch.__dict__[key].get_keys()) == 0 diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index e4ea4a8b1..04505ace3 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -393,9 +393,6 @@ def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None: def __len__(self) -> int: ... - def is_empty(self, recurse: bool = False) -> bool: - ... - def split( self, size: int, @@ -514,7 +511,7 @@ def __getitem__(self, index: str | IndexType) -> Any: if len(batch_items) > 0: new_batch = Batch() for batch_key, obj in batch_items: - if isinstance(obj, Batch) and obj.is_empty(): + if isinstance(obj, Batch) and len(obj.get_keys()) == 0: new_batch.__dict__[batch_key] = Batch() else: new_batch.__dict__[batch_key] = obj[index] @@ -574,13 +571,13 @@ def __iadd__(self, other: Self | Number | np.number) -> Self: other.__dict__.values(), strict=True, ): # TODO are keys consistent? - if isinstance(obj, Batch) and obj.is_empty(): + if isinstance(obj, Batch) and len(obj.get_keys()) == 0: continue self.__dict__[batch_key] += value return self if _is_number(other): for batch_key, obj in self.items(): - if isinstance(obj, Batch) and obj.is_empty(): + if isinstance(obj, Batch) and len(obj.get_keys()) == 0: continue self.__dict__[batch_key] += other return self @@ -594,7 +591,7 @@ def __imul__(self, value: Number | np.number) -> Self: """Algebraic multiplication with a scalar value in-place.""" assert _is_number(value), "Only multiplication by a number is supported." 
for batch_key, obj in self.__dict__.items(): - if isinstance(obj, Batch) and obj.is_empty(): + if isinstance(obj, Batch) and len(obj.get_keys()) == 0: continue self.__dict__[batch_key] *= value return self @@ -607,7 +604,7 @@ def __itruediv__(self, value: Number | np.number) -> Self: """Algebraic division with a scalar value in-place.""" assert _is_number(value), "Only division by a number is supported." for batch_key, obj in self.__dict__.items(): - if isinstance(obj, Batch) and obj.is_empty(): + if isinstance(obj, Batch) and len(obj.get_keys()) == 0: continue self.__dict__[batch_key] /= value return self @@ -722,7 +719,7 @@ def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None: { batch_key for batch_key, obj in batch.items() - if not (isinstance(obj, Batch) and obj.is_empty()) + if not (isinstance(obj, Batch) and len(obj.get_keys()) == 0) } for batch in batches ] @@ -753,7 +750,7 @@ def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None: if key not in batch.__dict__: continue value = batch.get(key) - if isinstance(value, Batch) and value.is_empty(): + if isinstance(value, Batch) and len(value.get_keys()) == 0: continue try: self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value @@ -772,7 +769,7 @@ def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None: batch_list.append(Batch(batch)) elif isinstance(batch, Batch): # x.is_empty() means that x is Batch() and should be ignored - if not batch.is_empty(): + if len(batch.get_keys()) != 0: batch_list.append(batch) else: raise ValueError(f"Cannot concatenate {type(batch)} in Batch.cat_") @@ -783,16 +780,16 @@ def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None: # x.is_empty(recurse=True) here means x is a nested empty batch # like Batch(a=Batch), and we have to treat it as length zero and # keep it. - lens = [0 if batch.is_empty(recurse=True) else len(batch) for batch in batches] + lens = [0 if len(batch) == 0 else len(batch) for batch in batches] except TypeError as exception: raise ValueError( "Batch.cat_ meets an exception. Maybe because there is any " f"scalar in {batches} but Batch.cat_ does not support the " "concatenation of scalar.", ) from exception - if not self.is_empty(): + if len(self.get_keys()) != 0: batches = [self, *list(batches)] - lens = [0 if self.is_empty(recurse=True) else len(self), *lens] + lens = [0 if len(self) == 0 else len(self), *lens] self.__cat(batches, lens) @staticmethod @@ -810,14 +807,14 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None batch_list.append(Batch(batch)) elif isinstance(batch, Batch): # x.is_empty() means that x is Batch() and should be ignored - if not batch.is_empty(): + if len(batch.get_keys()) != 0: batch_list.append(batch) else: raise ValueError(f"Cannot concatenate {type(batch)} in Batch.stack_") if len(batch_list) == 0: return batches = batch_list - if not self.is_empty(): + if len(self.get_keys()) != 0: batches = [self, *batches] # collect non-empty keys keys_map = [ @@ -930,7 +927,7 @@ def __len__(self) -> int: # TODO: causes inconsistent behavior to batch with empty batches # and batch with empty sequences of other type. 
Remove, but only after # Buffer and Collectors have been improved to no longer rely on this - if isinstance(obj, Batch) and obj.is_empty(recurse=True): + if isinstance(obj, Batch) and len(obj) == 0: continue if hasattr(obj, "__len__") and (isinstance(obj, Batch) or obj.ndim > 0): lens.append(len(obj)) @@ -940,45 +937,10 @@ def __len__(self) -> int: return 0 return min(lens) - def is_empty(self, recurse: bool = False) -> bool: - """Test if a Batch is empty. - - If ``recurse=True``, it further tests the values of the object; else - it only tests the existence of any key. - - ``b.is_empty(recurse=True)`` is mainly used to distinguish - ``Batch(a=Batch(a=Batch()))`` and ``Batch(a=1)``. They both raise - exceptions when applied to ``len()``, but the former can be used in - ``cat``, while the latter is a scalar and cannot be used in ``cat``. - - Another usage is in ``__len__``, where we have to skip checking the - length of recursively empty Batch. - :: - - >>> Batch().is_empty() - True - >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty() - False - >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True) - True - >>> Batch(d=1).is_empty() - False - >>> Batch(a=np.float64(1.0)).is_empty() - False - """ - if len(self.__dict__) == 0: - return True - if not recurse: - return False - return all( - False if not isinstance(obj, Batch) else obj.is_empty(recurse=True) - for obj in self.values() - ) - @property def shape(self) -> list[int]: """Return self.shape.""" - if self.is_empty(): + if len(self.get_keys()) == 0: return [] data_shape = [] for obj in self.__dict__.values(): diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index b1719be56..2ecddc5ce 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -208,7 +208,7 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray: self._index = (self._index + 1) % self.maxsize self._size = min(self._size + 1, self.maxsize) to_indices = np.array(to_indices) - if self._meta.is_empty(): + if len(self._meta.get_keys()) == 0: self._meta = create_value(buffer._meta, self.maxsize, stack=False) # type: ignore self._meta[to_indices] = buffer._meta[from_indices] return to_indices @@ -284,7 +284,7 @@ def add( batch.done = batch.done.astype(bool) batch.terminated = batch.terminated.astype(bool) batch.truncated = batch.truncated.astype(bool) - if self._meta.is_empty(): + if len(self._meta.get_keys()) == 0: self._meta = create_value(batch, self.maxsize, stack) # type: ignore else: # dynamic key pops up in batch alloc_by_keys_diff(self._meta, batch, self.maxsize, stack) @@ -377,7 +377,7 @@ def get( return np.stack(stack, axis=indices.ndim) except IndexError as exception: - if not (isinstance(val, Batch) and val.is_empty()): + if not (isinstance(val, Batch) and len(val.get_keys()) == 0): raise exception # val != Batch() return Batch() diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py index 90480257a..38db7d120 100644 --- a/tianshou/data/buffer/manager.py +++ b/tianshou/data/buffer/manager.py @@ -29,7 +29,7 @@ def __init__(self, buffer_list: list[ReplayBuffer] | list[HERReplayBuffer]) -> N buffer_type = type(self.buffers[0]) kwargs = self.buffers[0].options for buf in self.buffers: - assert buf._meta.is_empty() + assert len(buf._meta.get_keys()) == 0 assert isinstance(buf, buffer_type) assert buf.options == kwargs offset.append(size) @@ -161,7 +161,7 @@ def add( batch.done = batch.done.astype(bool) batch.terminated = batch.terminated.astype(bool) batch.truncated = 
batch.truncated.astype(bool) - if self._meta.is_empty(): + if len(self._meta.get_keys()) == 0: self._meta = create_value(batch, self.maxsize, stack=False) # type: ignore else: # dynamic key pops up in batch alloc_by_keys_diff(self._meta, batch, self.maxsize, False) diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 81cfe0a6d..e88214d45 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -272,7 +272,7 @@ def learn( # type: ignore agent_id_to_stats = {} for agent_id, policy in self.policies.items(): data = batch[agent_id] - if not data.is_empty(): + if len(data.get_keys()) != 0: train_stats = policy.learn(batch=data, **kwargs) agent_id_to_stats[agent_id] = train_stats return MapTrainingStats(agent_id_to_stats)
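
For downstream code that still calls the removed ``Batch.is_empty``, here is a minimal migration sketch. It relies only on the replacement idioms used throughout this patch (``len(batch.get_keys()) == 0`` for direct emptiness, ``len(batch) == 0`` for recursive emptiness); the helper names below are illustrative and not part of Tianshou's API.

```python
import numpy as np

from tianshou.data import Batch


def is_direct_empty(batch: Batch) -> bool:
    # Replacement for the removed batch.is_empty():
    # True iff the batch has no keys at all, i.e. it is a plain Batch().
    return len(batch.get_keys()) == 0


def is_recursively_empty(batch: Batch) -> bool:
    # Replacement for the removed batch.is_empty(recurse=True):
    # True iff the batch has no scalar/tensor leaf nodes,
    # e.g. Batch(a=Batch(), b=Batch(c=Batch())).
    # Caveat: len() raises TypeError for a batch whose only leaves are scalars
    # (e.g. Batch(d=1)), which the old recurse=True check reported as non-empty,
    # so guard such cases separately if they can occur.
    return len(batch) == 0


if __name__ == "__main__":
    assert is_direct_empty(Batch())
    assert not is_direct_empty(Batch(a=Batch(), b=Batch(c=Batch())))
    assert is_recursively_empty(Batch(a=Batch(), b=Batch(c=Batch())))
    assert not is_direct_empty(Batch(a=np.float64(1.0)))
```

This mirrors the doctest-style examples added to ``docs/01_tutorials/03_batch.rst`` above.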