Skip to content

Commit

Permalink
Fix type check in atari wrapper, solves thu-ml#1111
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Panchenko committed Apr 16, 2024
1 parent 60d1ba1 commit 049907d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
4 changes: 2 additions & 2 deletions examples/atari/atari_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_args()


def test_dqn(args: argparse.Namespace = get_args()) -> None:
def main(args: argparse.Namespace = get_args()) -> None:
env, train_envs, test_envs = make_atari_env(
args.task,
args.seed,
Expand Down Expand Up @@ -260,4 +260,4 @@ def watch() -> None:


if __name__ == "__main__":
test_dqn(get_args())
main(get_args())
32 changes: 19 additions & 13 deletions examples/atari/atari_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,20 @@ def _parse_reset_result(reset_result: tuple) -> tuple[tuple, dict, bool]:
return reset_result, {}, contains_info


def get_space_dtype(obs_space: gym.spaces.Box) -> type[np.floating] | type[np.integer]:
    """Return the abstract numpy scalar family that the space's dtype belongs to.

    :param obs_space: the space whose ``dtype`` attribute is inspected.
    :return: ``np.integer`` if the dtype is a sub-dtype of ``np.integer``,
        ``np.floating`` if it is a sub-dtype of ``np.floating``.
    :raises TypeError: if the dtype belongs to neither of the two families.
    """
    # The integer family is tested first, then the floating family.
    for scalar_family in (np.integer, np.floating):
        if np.issubdtype(obs_space.dtype, scalar_family):
            return scalar_family
    raise TypeError(
        f"Unsupported observation space dtype: {obs_space.dtype}. "
        f"This might be a bug in tianshou or gymnasium, please report it!",
    )


class NoopResetEnv(gym.Wrapper):
"""Sample initial states by taking random number of no-ops on reset.
Expand Down Expand Up @@ -199,12 +213,8 @@ def __init__(self, env: gym.Env) -> None:
super().__init__(env)
self.size = 84
obs_space = env.observation_space
obs_space_dtype: type[np.floating[Any]] | type[np.integer[Any]]
if np.issubdtype(type(obs_space.dtype), np.integer):
obs_space_dtype = np.integer
elif np.issubdtype(type(obs_space.dtype), np.floating):
obs_space_dtype = np.floating
assert isinstance(obs_space, gym.spaces.Box)
obs_space_dtype = get_space_dtype(obs_space)
self.observation_space = gym.spaces.Box(
low=np.min(obs_space.low),
high=np.max(obs_space.high),
Expand Down Expand Up @@ -273,15 +283,11 @@ def __init__(self, env: gym.Env, n_frames: int) -> None:
obs_space_shape = env.observation_space.shape
assert obs_space_shape is not None
shape = (n_frames, *obs_space_shape)
assert isinstance(env.observation_space, gym.spaces.Box)
obs_space_dtype: type[np.floating[Any]] | type[np.integer[Any]]
if np.issubdtype(type(obs_space.dtype), np.integer):
obs_space_dtype = np.integer
elif np.issubdtype(type(obs_space.dtype), np.floating):
obs_space_dtype = np.floating
assert isinstance(obs_space, gym.spaces.Box)
obs_space_dtype = get_space_dtype(obs_space)
self.observation_space = gym.spaces.Box(
low=np.min(env.observation_space.low),
high=np.max(env.observation_space.high),
low=np.min(obs_space.low),
high=np.max(obs_space.high),
shape=shape,
dtype=obs_space_dtype,
)
Expand Down

0 comments on commit 049907d

Please sign in to comment.