diff --git a/test/base/test_batch.py b/test/base/test_batch.py index b45ea00c5..d1b459b76 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -655,7 +655,7 @@ def test_to_dict_nested_batch_no_recurse() -> None: nested_batch = Batch(c=3) batch = Batch(a=1, b=nested_batch) expected = {"a": np.asanyarray(1), "b": nested_batch} - assert not DeepDiff(batch.to_dict(), expected) + assert not DeepDiff(batch.to_dict(recurse=False), expected) @staticmethod def test_to_dict_nested_batch_recurse() -> None: diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 95b25f741..b09847a20 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -405,7 +405,7 @@ def split( """ ... - def to_dict(self, recurse: bool = False) -> dict[str, Any]: + def to_dict(self, recurse: bool = True) -> dict[str, Any]: ... def to_list_of_dicts(self) -> list[dict[str, Any]]: @@ -442,7 +442,7 @@ def __init__( # Feels like kwargs could be just merged into batch_dict in the beginning self.__init__(kwargs, copy=copy) # type: ignore - def to_dict(self, recurse: bool = False) -> dict[str, Any]: + def to_dict(self, recurse: bool = True) -> dict[str, Any]: result = {} for k, v in self.__dict__.items(): if recurse and isinstance(v, Batch):