diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 04505ace3..03b3d9849 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -190,7 +190,7 @@ def alloc_by_keys_diff( if key in meta.get_keys(): if isinstance(meta[key], Batch) and isinstance(batch[key], Batch): alloc_by_keys_diff(meta[key], batch[key], size, stack) - elif isinstance(meta[key], Batch) and meta[key].is_empty(): + elif isinstance(meta[key], Batch) and len(meta[key].get_keys()) == 0: meta[key] = create_value(batch[key], size, stack) else: meta[key] = create_value(batch[key], size, stack) @@ -768,7 +768,6 @@ def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None: if len(batch) > 0: batch_list.append(Batch(batch)) elif isinstance(batch, Batch): - # x.is_empty() means that x is Batch() and should be ignored if len(batch.get_keys()) != 0: batch_list.append(batch) else: @@ -777,7 +776,7 @@ def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None: return batches = batch_list try: - # x.is_empty(recurse=True) here means x is a nested empty batch + # len(batch) here means batch is a nested empty batch # like Batch(a=Batch), and we have to treat it as length zero and # keep it. lens = [0 if len(batch) == 0 else len(batch) for batch in batches] @@ -806,7 +805,6 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None if len(batch) > 0: batch_list.append(Batch(batch)) elif isinstance(batch, Batch): - # x.is_empty() means that x is Batch() and should be ignored if len(batch.get_keys()) != 0: batch_list.append(batch) else: @@ -821,7 +819,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None { batch_key for batch_key, obj in batch.items() - if not (isinstance(obj, BatchProtocol) and obj.is_empty()) + if not (isinstance(obj, BatchProtocol) and len(obj.get_keys()) == 0) } for batch in batches ] @@ -867,7 +865,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None # TODO: fix code/annotations s.t. the ignores can be removed if ( isinstance(value, BatchProtocol) # type: ignore - and value.is_empty() # type: ignore + and len(value.get_keys()) == 0 # type: ignore ): continue # type: ignore try: