Skip to content

Commit

Permalink
Minor: use Self type where appropriate (#942)
Browse files Browse the repository at this point in the history
Small typing improvement, related to
#915 (comment)
  • Loading branch information
MischaPanch authored Sep 19, 2023
1 parent 2cc34fb commit c8e7d02
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 40 deletions.
51 changes: 26 additions & 25 deletions tianshou/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import (
Any,
Protocol,
Self,
TypeVar,
Union,
cast,
Expand Down Expand Up @@ -232,7 +233,7 @@ def __getitem__(self, index: str) -> Any:
...

@overload
def __getitem__(self: TBatch, index: IndexType) -> TBatch:
def __getitem__(self, index: IndexType) -> Self:
...

def __getitem__(self, index: str | IndexType) -> Any:
Expand All @@ -241,22 +242,22 @@ def __getitem__(self, index: str | IndexType) -> Any:
def __setitem__(self, index: str | IndexType, value: Any) -> None:
...

def __iadd__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
def __iadd__(self, other: Self | Number | np.number) -> Self:
...

def __add__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
def __add__(self, other: Self | Number | np.number) -> Self:
...

def __imul__(self: TBatch, value: Number | np.number) -> TBatch:
def __imul__(self, value: Number | np.number) -> Self:
...

def __mul__(self: TBatch, value: Number | np.number) -> TBatch:
def __mul__(self, value: Number | np.number) -> Self:
...

def __itruediv__(self: TBatch, value: Number | np.number) -> TBatch:
def __itruediv__(self, value: Number | np.number) -> Self:
...

def __truediv__(self: TBatch, value: Number | np.number) -> TBatch:
def __truediv__(self, value: Number | np.number) -> Self:
...

def __repr__(self) -> str:
Expand All @@ -274,7 +275,7 @@ def to_torch(
"""Change all numpy.ndarray to torch.Tensor in-place."""
...

def cat_(self, batches: TBatch | Sequence[dict | TBatch]) -> None:
def cat_(self, batches: Self | Sequence[dict | Self]) -> None:
"""Concatenate a list of (or one) Batch objects into current batch."""
...

Expand All @@ -298,7 +299,7 @@ def cat(batches: Sequence[dict | TBatch]) -> TBatch:
"""
...

def stack_(self, batches: Sequence[dict | TBatch], axis: int = 0) -> None:
def stack_(self, batches: Sequence[dict | Self], axis: int = 0) -> None:
"""Stack a list of Batch object into current batch."""
...

Expand Down Expand Up @@ -327,7 +328,7 @@ def stack(batches: Sequence[dict | TBatch], axis: int = 0) -> TBatch:
"""
...

def empty_(self: TBatch, index: slice | IndexType | None = None) -> TBatch:
def empty_(self, index: slice | IndexType | None = None) -> Self:
"""Return an empty Batch object with 0 or None filled.
If "index" is specified, it will only reset the specific indexed-data.
Expand Down Expand Up @@ -362,7 +363,7 @@ def empty(batch: TBatch, index: IndexType | None = None) -> TBatch:
"""
...

def update(self, batch: dict | TBatch | None = None, **kwargs: Any) -> None:
def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None:
"""Update this batch from another dict/Batch."""
...

Expand All @@ -373,11 +374,11 @@ def is_empty(self, recurse: bool = False) -> bool:
...

def split(
self: TBatch,
self,
size: int,
shuffle: bool = True,
merge_last: bool = False,
) -> Iterator[TBatch]:
) -> Iterator[Self]:
"""Split whole data into multiple small batches.
:param int size: divide the data batch with the given size, but one
Expand Down Expand Up @@ -457,7 +458,7 @@ def __getitem__(self, index: str) -> Any:
...

@overload
def __getitem__(self: TBatch, index: IndexType) -> TBatch:
def __getitem__(self, index: IndexType) -> Self:
...

def __getitem__(self, index: str | IndexType) -> Any:
Expand Down Expand Up @@ -501,7 +502,7 @@ def __setitem__(self, index: str | IndexType, value: Any) -> None:
else:
self.__dict__[key][index] = None

def __iadd__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
def __iadd__(self, other: Self | Number | np.number) -> Self:
"""Algebraic addition with another Batch instance in-place."""
if isinstance(other, Batch):
for (batch_key, obj), value in zip(
Expand All @@ -521,11 +522,11 @@ def __iadd__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
return self
raise TypeError("Only addition of Batch or number is supported.")

def __add__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
def __add__(self, other: Self | Number | np.number) -> Self:
"""Algebraic addition with another Batch instance out-of-place."""
return deepcopy(self).__iadd__(other)

def __imul__(self: TBatch, value: Number | np.number) -> TBatch:
def __imul__(self, value: Number | np.number) -> Self:
"""Algebraic multiplication with a scalar value in-place."""
assert _is_number(value), "Only multiplication by a number is supported."
for batch_key, obj in self.__dict__.items():
Expand All @@ -534,11 +535,11 @@ def __imul__(self: TBatch, value: Number | np.number) -> TBatch:
self.__dict__[batch_key] *= value
return self

def __mul__(self: TBatch, value: Number | np.number) -> TBatch:
def __mul__(self, value: Number | np.number) -> Self:
"""Algebraic multiplication with a scalar value out-of-place."""
return deepcopy(self).__imul__(value)

def __itruediv__(self: TBatch, value: Number | np.number) -> TBatch:
def __itruediv__(self, value: Number | np.number) -> Self:
"""Algebraic division with a scalar value in-place."""
assert _is_number(value), "Only division by a number is supported."
for batch_key, obj in self.__dict__.items():
Expand All @@ -547,7 +548,7 @@ def __itruediv__(self: TBatch, value: Number | np.number) -> TBatch:
self.__dict__[batch_key] /= value
return self

def __truediv__(self: TBatch, value: Number | np.number) -> TBatch:
def __truediv__(self, value: Number | np.number) -> Self:
"""Algebraic division with a scalar value out-of-place."""
return deepcopy(self).__itruediv__(value)

Expand Down Expand Up @@ -604,7 +605,7 @@ def to_torch(
obj = obj.type(dtype) # noqa: PLW2901
self.__dict__[batch_key] = obj

def __cat(self: TBatch, batches: Sequence[dict | TBatch], lens: list[int]) -> None:
def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
"""Private method for Batch.cat_.
::
Expand Down Expand Up @@ -798,7 +799,7 @@ def stack(batches: Sequence[dict | TBatch], axis: int = 0) -> TBatch:
# can't cast to a generic type, so we have to ignore the type here
return batch # type: ignore

def empty_(self: TBatch, index: slice | IndexType | None = None) -> TBatch:
def empty_(self, index: slice | IndexType | None = None) -> Self:
for batch_key, obj in self.items():
if isinstance(obj, torch.Tensor): # most often case
self.__dict__[batch_key][index] = 0
Expand Down Expand Up @@ -826,7 +827,7 @@ def empty_(self: TBatch, index: slice | IndexType | None = None) -> TBatch:
def empty(batch: TBatch, index: IndexType | None = None) -> TBatch:
return deepcopy(batch).empty_(index)

def update(self, batch: dict | TBatch | None = None, **kwargs: Any) -> None:
def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None:
if batch is None:
self.update(kwargs)
return
Expand Down Expand Up @@ -902,11 +903,11 @@ def shape(self) -> list[int]:
)

def split(
self: TBatch,
self,
size: int,
shuffle: bool = True,
merge_last: bool = False,
) -> Iterator[TBatch]:
) -> Iterator[Self]:
length = len(self)
if size == -1:
size = length
Expand Down
6 changes: 3 additions & 3 deletions tianshou/data/buffer/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, cast
from typing import Any, Self, cast

import h5py
import numpy as np
Expand Down Expand Up @@ -111,7 +111,7 @@ def save_hdf5(self, path: str, compression: str | None = None) -> None:
to_hdf5(self.__dict__, f, compression=compression)

@classmethod
def load_hdf5(cls, path: str, device: str | None = None) -> "ReplayBuffer":
def load_hdf5(cls, path: str, device: str | None = None) -> Self:
"""Load replay buffer from HDF5 file."""
with h5py.File(path, "r") as f:
buf = cls.__new__(cls)
Expand All @@ -128,7 +128,7 @@ def from_data(
truncated: h5py.Dataset,
done: h5py.Dataset,
obs_next: h5py.Dataset,
) -> "ReplayBuffer":
) -> Self:
size = len(obs)
assert all(
len(dset) == size for dset in [obs, act, rew, terminated, truncated, done, obs_next]
Expand Down
6 changes: 3 additions & 3 deletions tianshou/data/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ def __init__(
policy: BasePolicy,
env: gym.Env | BaseVectorEnv,
buffer: ReplayBuffer | None = None,
preprocess_fn: Callable[..., Batch] | None = None,
preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None,
exploration_noise: bool = False,
) -> None:
super().__init__()
if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
self.env = DummyVectorEnv([lambda: env])
self.env = DummyVectorEnv([lambda: env]) # type: ignore
else:
self.env = env # type: ignore
self.env_num = len(self.env)
Expand Down Expand Up @@ -413,7 +413,7 @@ def __init__(
policy: BasePolicy,
env: BaseVectorEnv,
buffer: ReplayBuffer | None = None,
preprocess_fn: Callable[..., Batch] | None = None,
preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None,
exploration_noise: bool = False,
) -> None:
# assert env.is_async
Expand Down
9 changes: 2 additions & 7 deletions tianshou/env/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
from typing import TYPE_CHECKING, Any, Union
from typing import Any

import cloudpickle
import gymnasium
import numpy as np

from tianshou.env.pettingzoo_env import PettingZooEnv

if TYPE_CHECKING:
import gym

# TODO: remove gym entirely? Currently mypy complains in several places
# if gym.Env is removed from the Union
ENV_TYPE = Union[gymnasium.Env, "gym.Env", PettingZooEnv]
ENV_TYPE = gymnasium.Env | PettingZooEnv

gym_new_venv_step_type = tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]

Expand Down
4 changes: 2 additions & 2 deletions tianshou/policy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch import nn

from tianshou.data import ReplayBuffer, to_numpy, to_torch_as
from tianshou.data.batch import BatchProtocol, TBatch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol
from tianshou.utils import MultipleLRSchedulers

Expand Down Expand Up @@ -185,7 +185,7 @@ def forward(
"""

@overload
def map_action(self, act: TBatch) -> TBatch:
def map_action(self, act: BatchProtocol) -> BatchProtocol:
...

@overload
Expand Down

0 comments on commit c8e7d02

Please sign in to comment.