Allow two (same/different) Batch objs to be tested for equality (thu-ml#1098)

Closes: thu-ml#1086

### API Extensions

- `Batch` received a new method: `to_numpy_`. thu-ml#1098
- `to_dict` in `Batch` now also supports non-recursive conversion. thu-ml#1098
- `Batch.__eq__` is now implemented, so batches can be compared for semantic
equality (see the sketch below). thu-ml#1098
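
For illustration (not part of this commit's diff), a minimal sketch of the new behavior, mirroring the tests added in `test/base/test_batch.py`:

```python
from tianshou.data import Batch

# Semantic equality: key order, tuple/list values, and dict/Batch sub-batches
# are treated equivalently.
assert Batch(a=1, b=2) == Batch(b=2, a=1)
assert Batch(a=(1, 2)) == Batch(a=[1, 2])
assert Batch(a={"x": 1}) == Batch(a=Batch(x=1))

# Non-recursive conversion keeps nested Batch objects as values;
# recursive conversion turns them into plain dicts.
batch = Batch(a=1, b=Batch(c=3))
shallow = batch.to_dict(recursive=False)  # {"a": array(1), "b": Batch(c=3)}
deep = batch.to_dict(recursive=True)      # {"a": array(1), "b": {"c": array(3)}}
```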

### Breaking Changes

- The method `to_numpy` in `data.utils.batch.Batch` is no longer in-place.
Instead, a new method `to_numpy_` performs the conversion in-place
(a migration sketch follows below). thu-ml#1098
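
A minimal migration sketch (illustrative, not part of the commit), mirroring the new `TestToNumpy` tests below:

```python
import numpy as np
import torch

from tianshou.data import Batch

batch = Batch(a=1, b=torch.arange(5))

# Batch.to_numpy is now a static method that returns a converted copy;
# the original batch keeps its torch tensors.
new_batch = Batch.to_numpy(batch)
assert isinstance(batch.b, torch.Tensor)
assert isinstance(new_batch.b, np.ndarray)

# The in-place conversion has moved to to_numpy_().
batch.to_numpy_()
assert isinstance(batch.b, np.ndarray)
```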
dantp-ai authored Apr 16, 2024
1 parent 049907d commit ca4f74f
Showing 8 changed files with 245 additions and 15 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,9 @@
- Trainers can control whether collectors should be reset prior to training. #1063
- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063
- `SamplingConfig` supports `batch_size=None`. #1077
- `Batch` received a new method: `to_numpy_`. #1098
- `to_dict` in `Batch` now also supports non-recursive conversion. #1098
- `Batch.__eq__` is now implemented, so batches can be compared for semantic equality. #1098

### Internal Improvements
- `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
@@ -34,6 +37,7 @@ explicitly or pass `reset_before_collect=True`. #1063
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
continuous and discrete cases. #1032
- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077
- The method `to_numpy` in `data.utils.batch.Batch` is no longer in-place. Instead, a new method `to_numpy_` performs the conversion in-place. #1098

### Tests
- Fixed env seeding in test_sac_with_il.py so that the test doesn't fail randomly. #1081
4 changes: 2 additions & 2 deletions docs/01_tutorials/03_batch.rst
@@ -485,8 +485,8 @@ Miscellaneous Notes
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
>>> # data.to_numpy is also available
>>> data.to_numpy()
>>> # data.to_numpy_ is also available
>>> data.to_numpy_()

.. raw:: html

2 changes: 1 addition & 1 deletion docs/02_notebooks/L1_Batch.ipynb
@@ -331,7 +331,7 @@
},
"outputs": [],
"source": [
"batch_cat.to_numpy()\n",
"batch_cat.to_numpy_()\n",
"print(batch_cat)\n",
"batch_cat.to_torch()\n",
"print(batch_cat)"
35 changes: 33 additions & 2 deletions poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -25,6 +25,7 @@ exclude = ["test/*", "examples/*", "docs/*"]

[tool.poetry.dependencies]
python = "^3.11"
deepdiff = "^7.0.1"
gymnasium = "^0.28.0"
h5py = "^3.9.0"
numba = "^0.57.1"
166 changes: 164 additions & 2 deletions test/base/test_batch.py
@@ -2,12 +2,13 @@
import pickle
import sys
from itertools import starmap
from typing import cast
from typing import Any, cast

import networkx as nx
import numpy as np
import pytest
import torch
from deepdiff import DeepDiff

from tianshou.data import Batch, to_numpy, to_torch

@@ -477,7 +478,7 @@ def test_batch_from_to_numpy_without_copy() -> None:
a_mem_addr_orig = batch.a.__array_interface__["data"][0]
c_mem_addr_orig = batch.b.c.__array_interface__["data"][0]
batch.to_torch()
batch.to_numpy()
batch.to_numpy_()
a_mem_addr_new = batch.a.__array_interface__["data"][0]
c_mem_addr_new = batch.b.c.__array_interface__["data"][0]
assert a_mem_addr_new == a_mem_addr_orig
@@ -565,6 +566,167 @@ def test_batch_standard_compatibility() -> None:
Batch()[0]


class TestBatchEquality:
@staticmethod
def test_keys_different() -> None:
batch1 = Batch(a=[1, 2], b=[100, 50])
batch2 = Batch(b=[1, 2], c=[100, 50])
assert batch1 != batch2

@staticmethod
def test_keys_missing() -> None:
batch1 = Batch(a=[1, 2], b=[2, 3, 4])
batch2 = Batch(a=[1, 2], b=[2, 3, 4])
batch2.pop("b")
assert batch1 != batch2

@staticmethod
def test_types_keys_different() -> None:
batch1 = Batch(a=[1, 2, 3], b=[4, 5])
batch2 = Batch(a=[1, 2, 3], b=Batch(a=[4, 5]))
assert batch1 != batch2

@staticmethod
def test_array_types_different() -> None:
batch1 = Batch(a=[1, 2, 3], b=np.array([4, 5]))
batch2 = Batch(a=[1, 2, 3], b=torch.Tensor([4, 5]))
assert batch1 != batch2

@staticmethod
def test_nested_values_different() -> None:
batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
batch2 = Batch(a=Batch(a=[1, 2, 4]), b=[4, 5])
assert batch1 != batch2

@staticmethod
def test_nested_shapes_different() -> None:
batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
batch2 = Batch(a=Batch(a=[1, 4]), b=[4, 5])
assert batch1 != batch2

@staticmethod
def test_slice_equal() -> None:
batch1 = Batch(a=[1, 2, 3])
assert batch1[:2] == batch1[:2]

@staticmethod
def test_slice_ellipsis_equal() -> None:
batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5], c=[100, 1001, 2000])
assert batch1[..., 1:] == batch1[..., 1:]

@staticmethod
def test_empty_batches() -> None:
assert Batch() == Batch()

@staticmethod
def test_different_order_keys() -> None:
assert Batch(a=1, b=2) == Batch(b=2, a=1)

@staticmethod
def test_tuple_and_list_types() -> None:
assert Batch(a=(1, 2)) == Batch(a=[1, 2])

@staticmethod
def test_subbatch_dict_and_batch_types() -> None:
assert Batch(a={"x": 1}) == Batch(a=Batch(x=1))


class TestBatchToDict:
@staticmethod
def test_to_dict_empty_batch_no_recurse() -> None:
batch = Batch()
expected: dict[Any, Any] = {}
assert batch.to_dict() == expected

@staticmethod
def test_to_dict_with_simple_values_recurse() -> None:
batch = Batch(a=1, b="two", c=np.array([3, 4]))
expected = {"a": np.asanyarray(1), "b": "two", "c": np.array([3, 4])}
assert not DeepDiff(batch.to_dict(recursive=True), expected)

@staticmethod
def test_to_dict_simple() -> None:
batch = Batch(a=1, b="two")
expected = {"a": np.asanyarray(1), "b": "two"}
assert batch.to_dict() == expected

@staticmethod
def test_to_dict_nested_batch_no_recurse() -> None:
nested_batch = Batch(c=3)
batch = Batch(a=1, b=nested_batch)
expected = {"a": np.asanyarray(1), "b": nested_batch}
assert not DeepDiff(batch.to_dict(recursive=False), expected)

@staticmethod
def test_to_dict_nested_batch_recurse() -> None:
nested_batch = Batch(c=3)
batch = Batch(a=1, b=nested_batch)
expected = {"a": np.asanyarray(1), "b": {"c": np.asanyarray(3)}}
assert not DeepDiff(batch.to_dict(recursive=True), expected)

@staticmethod
def test_to_dict_multiple_nested_batch_recurse() -> None:
nested_batch = Batch(c=Batch(e=3), d=[100, 200, 300])
batch = Batch(a=1, b=nested_batch)
expected = {
"a": np.asanyarray(1),
"b": {"c": {"e": np.asanyarray(3)}, "d": np.array([100, 200, 300])},
}
assert not DeepDiff(batch.to_dict(recursive=True), expected)

@staticmethod
def test_to_dict_array() -> None:
batch = Batch(a=np.array([1, 2, 3]))
expected = {"a": np.array([1, 2, 3])}
assert not DeepDiff(batch.to_dict(), expected)

@staticmethod
def test_to_dict_nested_batch_with_array() -> None:
nested_batch = Batch(c=np.array([4, 5]))
batch = Batch(a=1, b=nested_batch)
expected = {"a": np.asanyarray(1), "b": {"c": np.array([4, 5])}}
assert not DeepDiff(batch.to_dict(recursive=True), expected)

@staticmethod
def test_to_dict_torch_tensor() -> None:
t1 = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
batch = Batch(a=t1)
t2 = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
expected = {"a": t2}
assert not DeepDiff(batch.to_dict(), expected)

@staticmethod
def test_to_dict_nested_batch_with_torch_tensor() -> None:
nested_batch = Batch(c=torch.tensor([4, 5]).detach().cpu().numpy())
batch = Batch(a=1, b=nested_batch)
expected = {"a": np.asanyarray(1), "b": {"c": torch.tensor([4, 5]).detach().cpu().numpy()}}
assert not DeepDiff(batch.to_dict(recursive=True), expected)


class TestToNumpy:
"""Tests for `Batch.to_numpy()` and its in-place counterpart `Batch.to_numpy_()` ."""

@staticmethod
def test_to_numpy() -> None:
batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
new_batch: Batch = Batch.to_numpy(batch)
assert id(batch) != id(new_batch)
assert isinstance(batch.b, torch.Tensor)
assert isinstance(batch.c.d, torch.Tensor)

assert isinstance(new_batch.b, np.ndarray)
assert isinstance(new_batch.c.d, np.ndarray)

@staticmethod
def test_to_numpy_() -> None:
batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
id_batch = id(batch)
batch.to_numpy_()
assert id_batch == id(batch)
assert isinstance(batch.b, np.ndarray)
assert isinstance(batch.c.d, np.ndarray)


if __name__ == "__main__":
test_batch()
test_batch_over_batch()
46 changes: 39 additions & 7 deletions tianshou/data/batch.py
@@ -17,6 +17,7 @@

import numpy as np
import torch
from deepdiff import DeepDiff

_SingleIndexType = slice | int | EllipsisType
IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...]
@@ -268,7 +269,15 @@ def __repr__(self) -> str:
def __iter__(self) -> Iterator[Self]:
...

def to_numpy(self) -> None:
def __eq__(self, other: Any) -> bool:
...

@staticmethod
def to_numpy(batch: TBatch) -> TBatch:
"""Change all torch.Tensor to numpy.ndarray and return a new Batch."""
...

def to_numpy_(self) -> None:
"""Change all torch.Tensor to numpy.ndarray in-place."""
...

@@ -396,7 +405,7 @@ def split(
"""
...

def to_dict(self) -> dict[str, Any]:
def to_dict(self, recursive: bool = True) -> dict[str, Any]:
...

def to_list_of_dicts(self) -> list[dict[str, Any]]:
@@ -433,11 +442,11 @@ def __init__(
# Feels like kwargs could be just merged into batch_dict in the beginning
self.__init__(kwargs, copy=copy) # type: ignore

def to_dict(self) -> dict[str, Any]:
def to_dict(self, recursive: bool = True) -> dict[str, Any]:
result = {}
for k, v in self.__dict__.items():
if isinstance(v, Batch):
v = v.to_dict()
if recursive and isinstance(v, Batch):
v = v.to_dict(recursive=recursive)
result[k] = v
return result

@@ -503,6 +512,17 @@ def __getitem__(self, index: str | IndexType) -> Any:
return new_batch
raise IndexError("Cannot access item from empty Batch object.")

def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return False

this_batch_no_torch_tensor: Batch = Batch.to_numpy(self)
other_batch_no_torch_tensor: Batch = Batch.to_numpy(other)
this_dict = this_batch_no_torch_tensor.to_dict(recursive=True)
other_dict = other_batch_no_torch_tensor.to_dict(recursive=True)

return not DeepDiff(this_dict, other_dict)

def __iter__(self) -> Iterator[Self]:
# TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea
if len(self.__dict__) == 0:
@@ -602,12 +622,24 @@ def __repr__(self) -> str:
self_str = self.__class__.__name__ + "()"
return self_str

def to_numpy(self) -> None:
@staticmethod
def to_numpy(batch: TBatch) -> TBatch:
batch_dict = deepcopy(batch)
for batch_key, obj in batch_dict.items():
if isinstance(obj, torch.Tensor):
batch_dict.__dict__[batch_key] = obj.detach().cpu().numpy()
elif isinstance(obj, Batch):
obj = Batch.to_numpy(obj)
batch_dict.__dict__[batch_key] = obj

return batch_dict

def to_numpy_(self) -> None:
for batch_key, obj in self.items():
if isinstance(obj, torch.Tensor):
self.__dict__[batch_key] = obj.detach().cpu().numpy()
elif isinstance(obj, Batch):
obj.to_numpy()
obj.to_numpy_()

def to_torch(
self,