From a006e7c4207d6eb7810b9b2bb0a0a5081d3a7454 Mon Sep 17 00:00:00 2001 From: Daniel Plop <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 2 Aug 2024 21:59:37 +0200 Subject: [PATCH] Bugfix/batch eq for scalar (#1186) Fixes: https://github.com/thu-ml/tianshou/issues/1182 Note: Updated `test_batch.test_slice_distribution()` to use allclose (See: https://github.com/thu-ml/tianshou/pull/1181). --- CHANGELOG.md | 2 ++ test/base/test_batch.py | 31 +++++++++++++++++++++++++++---- tianshou/data/batch.py | 10 ++++++++++ 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d20489a6d..841d05b9c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -62,6 +62,7 @@ - Fix `output_dim` not being set if `features_only`=True and `output_dim_added_layer` is not None #1128 - `PPOPolicy`: - Fix `max_batchsize` not being used in `logp_old` computation inside `process_fn` #1168 +- Fix `Batch.__eq__` to allow comparing Batches with scalar array values #1185 ### Internal Improvements - `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063 @@ -108,6 +109,7 @@ continuous and discrete cases. #1032 - Fixed env seeding it `test_sac_with_il.py` so that the test doesn't fail randomly. #1081 - Improved CI triggers and added telemetry (if requested by user) #1177 - Improved environment used in tests. +- Improved tests for batch equality to check with scalar values #1185 ### Dependencies - [DeepDiff](https://github.com/seperman/deepdiff) added to help with diffs of batches in tests. 
#1098 diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 9accff86c..8839fe482 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -544,6 +544,28 @@ def test_nested_shapes_different() -> None: batch2 = Batch(a=Batch(a=[1, 4]), b=[4, 5]) assert batch1 != batch2 + @staticmethod + def test_array_scalars() -> None: + batch1 = Batch(a={"b": 1}) + batch2 = Batch(a={"b": 1}) + assert batch1 == batch2 + + batch3 = Batch(a={"c": 2}) + assert batch1 != batch3 + + batch4 = Batch(b={"b": 1}) + assert batch1 != batch4 + + batch5 = Batch(a={"b": 10}) + assert batch1 != batch5 + + batch6 = Batch(a={"b": [1]}) + assert batch1 == batch6 + + batch7 = Batch(a=1, b=5) + batch8 = Batch(a=1, b=5) + assert batch7 == batch8 + @staticmethod def test_slice_equal() -> None: batch1 = Batch(a=[1, 2, 3]) @@ -837,10 +859,11 @@ def test_slice_distribution() -> None: selected_idx = [1, 3] sliced_batch = batch[selected_idx] sliced_probs = cat_probs[selected_idx] - assert (sliced_batch.dist.probs == Categorical(probs=sliced_probs).probs).all() - assert ( - Categorical(probs=sliced_probs).probs == get_sliced_dist(dist, selected_idx).probs - ).all() + assert torch.allclose(sliced_batch.dist.probs, Categorical(probs=sliced_probs).probs) + assert torch.allclose( + Categorical(probs=sliced_probs).probs, + get_sliced_dist(dist, selected_idx).probs, + ) # retrieving a single index assert torch.allclose(batch[0].dist.probs, dist.probs[0]) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 650a5ccdf..c64b0fc75 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -678,6 +678,16 @@ def __eq__(self, other: Any) -> bool: this_batch_no_torch_tensor: Batch = Batch.to_numpy(self) other_batch_no_torch_tensor: Batch = Batch.to_numpy(other) + # DeepDiff 7.0.1 cannot compare 0-dimensional arrays + # so, we ensure with this transform that all array values have at least 1 dim + this_batch_no_torch_tensor.apply_values_transform( + 
values_transform=np.atleast_1d, + inplace=True, + ) + other_batch_no_torch_tensor.apply_values_transform( + values_transform=np.atleast_1d, + inplace=True, + ) this_dict = this_batch_no_torch_tensor.to_dict(recursive=True) other_dict = other_batch_no_torch_tensor.to_dict(recursive=True)