Skip to content

Commit

Permalink
Change argname from recurse -> recursive
Browse files Browse the repository at this point in the history
  • Loading branch information
dantp-ai committed Apr 16, 2024
1 parent def4f0d commit cebde45
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
12 changes: 6 additions & 6 deletions test/base/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ def test_to_dict_empty_batch_no_recurse() -> None:
def test_to_dict_with_simple_values_recurse() -> None:
batch = Batch(a=1, b="two", c=np.array([3, 4]))
expected = {"a": np.asanyarray(1), "b": "two", "c": np.array([3, 4])}
assert not DeepDiff(batch.to_dict(recurse=True), expected)
assert not DeepDiff(batch.to_dict(recursive=True), expected)

@staticmethod
def test_to_dict_simple() -> None:
Expand All @@ -655,14 +655,14 @@ def test_to_dict_nested_batch_no_recurse() -> None:
nested_batch = Batch(c=3)
batch = Batch(a=1, b=nested_batch)
expected = {"a": np.asanyarray(1), "b": nested_batch}
assert not DeepDiff(batch.to_dict(recurse=False), expected)
assert not DeepDiff(batch.to_dict(recursive=False), expected)

@staticmethod
def test_to_dict_nested_batch_recurse() -> None:
nested_batch = Batch(c=3)
batch = Batch(a=1, b=nested_batch)
expected = {"a": np.asanyarray(1), "b": {"c": np.asanyarray(3)}}
assert not DeepDiff(batch.to_dict(recurse=True), expected)
assert not DeepDiff(batch.to_dict(recursive=True), expected)

@staticmethod
def test_to_dict_multiple_nested_batch_recurse() -> None:
Expand All @@ -672,7 +672,7 @@ def test_to_dict_multiple_nested_batch_recurse() -> None:
"a": np.asanyarray(1),
"b": {"c": {"e": np.asanyarray(3)}, "d": np.array([100, 200, 300])},
}
assert not DeepDiff(batch.to_dict(recurse=True), expected)
assert not DeepDiff(batch.to_dict(recursive=True), expected)

@staticmethod
def test_to_dict_array() -> None:
Expand All @@ -685,7 +685,7 @@ def test_to_dict_nested_batch_with_array() -> None:
nested_batch = Batch(c=np.array([4, 5]))
batch = Batch(a=1, b=nested_batch)
expected = {"a": np.asanyarray(1), "b": {"c": np.array([4, 5])}}
assert not DeepDiff(batch.to_dict(recurse=True), expected)
assert not DeepDiff(batch.to_dict(recursive=True), expected)

@staticmethod
def test_to_dict_torch_tensor() -> None:
Expand All @@ -700,7 +700,7 @@ def test_to_dict_nested_batch_with_torch_tensor() -> None:
nested_batch = Batch(c=torch.tensor([4, 5]).detach().cpu().numpy())
batch = Batch(a=1, b=nested_batch)
expected = {"a": np.asanyarray(1), "b": {"c": torch.tensor([4, 5]).detach().cpu().numpy()}}
assert not DeepDiff(batch.to_dict(recurse=True), expected)
assert not DeepDiff(batch.to_dict(recursive=True), expected)


class TestToNumpy:
Expand Down
10 changes: 5 additions & 5 deletions tianshou/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,11 +442,11 @@ def __init__(
# Feels like kwargs could be just merged into batch_dict in the beginning
self.__init__(kwargs, copy=copy) # type: ignore

def to_dict(self, recurse: bool = True) -> dict[str, Any]:
def to_dict(self, recursive: bool = True) -> dict[str, Any]:
result = {}
for k, v in self.__dict__.items():
if recurse and isinstance(v, Batch):
v = v.to_dict(recurse=recurse)
if recursive and isinstance(v, Batch):
v = v.to_dict(recursive=recursive)
result[k] = v
return result

Expand Down Expand Up @@ -518,8 +518,8 @@ def __eq__(self, other: Any) -> bool:

this_batch_no_torch_tensor: Batch = Batch.to_numpy(self)
other_batch_no_torch_tensor: Batch = Batch.to_numpy(other)
this_dict = this_batch_no_torch_tensor.to_dict(recurse=True)
other_dict = other_batch_no_torch_tensor.to_dict(recurse=True)
this_dict = this_batch_no_torch_tensor.to_dict(recursive=True)
other_dict = other_batch_no_torch_tensor.to_dict(recursive=True)

return not DeepDiff(this_dict, other_dict)

Expand Down

0 comments on commit cebde45

Please sign in to comment.