Skip to content

Commit

Permalink
Use DeepDiff to test for Batch equality
Browse files Browse the repository at this point in the history
  * Note: `Batch.to_numpy()` should be extended to support also a non in-place operation.
  • Loading branch information
dantp-ai committed Apr 9, 2024
1 parent 68050c0 commit 98d611c
Showing 1 changed file with 6 additions and 26 deletions.
32 changes: 6 additions & 26 deletions tianshou/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
import torch
from deepdiff import DeepDiff

_SingleIndexType = slice | int | EllipsisType
IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...]
Expand Down Expand Up @@ -507,33 +508,12 @@ def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return False

this_dict = self.__dict__
other_dict = other.__dict__
self.to_numpy()
other.to_numpy()
this_dict = self.to_dict(recurse=True)
other_dict = other.to_dict(recurse=True)

if len(this_dict) != len(other_dict):
return False
for batch_key, obs in this_dict.items():
if batch_key not in other_dict:
return False

other_val = other.__dict__[batch_key]

if batch_key in other_dict:
if isinstance(obs, Batch) and isinstance(other_val, Batch):
if not obs == other_val:
return False
elif isinstance(obs, np.ndarray) and isinstance(other_val, np.ndarray):
if not np.all(np.equal(obs.shape, other_val.shape)):
return False
if not np.all(np.equal(obs, other_val)):
return False
elif isinstance(obs, torch.Tensor) and isinstance(other_val, torch.Tensor):
if not torch.equal(obs, other_val):
return False
else:
return False

return True
return not DeepDiff(this_dict, other_dict)

def __iter__(self) -> Iterator[Self]:
# TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea
Expand Down

0 comments on commit 98d611c

Please sign in to comment.