diff --git a/pettingzoo/utils/wrappers/terminate_illegal.py b/pettingzoo/utils/wrappers/terminate_illegal.py index a49d9a0be..79f95504a 100644 --- a/pettingzoo/utils/wrappers/terminate_illegal.py +++ b/pettingzoo/utils/wrappers/terminate_illegal.py @@ -1,4 +1,3 @@ -# pyright reportGeneralTypeIssues=false from __future__ import annotations from pettingzoo.utils.env import ActionType, AECEnv, AgentID, ObsType @@ -20,6 +19,7 @@ def __init__( self._illegal_value = illegal_reward self._prev_obs = None self._prev_info = None + self._terminated = False # terminated by an illegal move def reset(self, seed: int | None = None, options: dict | None = None) -> None: self._terminated = False @@ -42,7 +42,6 @@ def step(self, action: ActionType) -> None: if self._prev_obs is None: self.observe(self.agent_selection) if isinstance(self._prev_obs, dict): - assert self._prev_obs is not None assert ( "action_mask" in self._prev_obs ), f"`action_mask` not found in dictionary observation: {self._prev_obs}. Action mask must either be in `observation['action_mask']` or `info['action_mask']` to use TerminateIllegalWrapper." @@ -60,7 +59,7 @@ def step(self, action: ActionType) -> None: self.terminations[self.agent_selection] or self.truncations[self.agent_selection] ): - self._was_dead_step(action) # pyright: ignore[reportGeneralTypeIssues] + self.env.unwrapped._was_dead_step(action) elif ( not self.terminations[self.agent_selection] and not self.truncations[self.agent_selection] @@ -70,12 +69,10 @@ def step(self, action: ActionType) -> None: self.env.unwrapped._cumulative_rewards[self.agent_selection] = 0 self.env.unwrapped.terminations = {d: True for d in self.agents} self.env.unwrapped.truncations = {d: True for d in self.agents} - self._prev_obs = None - self._prev_info = None self.env.unwrapped.rewards = {d: 0 for d in self.truncations} self.env.unwrapped.rewards[current_agent] = float(self._illegal_value) - self._accumulate_rewards() - self._deads_step_first() + self.env.unwrapped._accumulate_rewards() + self.env.unwrapped._deads_step_first() self._terminated = True else: super().step(action) diff --git a/test/wrapper_test.py b/test/wrapper_test.py index 650fe328b..a03bd81b3 100644 --- a/test/wrapper_test.py +++ b/test/wrapper_test.py @@ -3,8 +3,13 @@ import pytest from pettingzoo.butterfly import pistonball_v6 -from pettingzoo.classic import texas_holdem_no_limit_v6 -from pettingzoo.utils.wrappers import MultiEpisodeEnv, MultiEpisodeParallelEnv +from pettingzoo.classic import texas_holdem_no_limit_v6, tictactoe_v3 +from pettingzoo.utils.wrappers import ( + BaseWrapper, + MultiEpisodeEnv, + MultiEpisodeParallelEnv, + TerminateIllegalWrapper, +) @pytest.mark.parametrize(("num_episodes"), [1, 2, 3, 4, 5, 6]) @@ -67,3 +72,65 @@ def test_multi_episode_parallel_env_wrapper(num_episodes) -> None: assert ( steps == num_episodes * 125 ), f"Expected to have 125 steps per episode, got {steps / num_episodes}." + + +def _do_game(env: TerminateIllegalWrapper, seed: int) -> None: + """Run a single game with reproducible random moves.""" + assert isinstance( + env, TerminateIllegalWrapper + ), "test_terminate_illegal must use TerminateIllegalWrapper" + env.reset(seed) + for agent in env.agents: + # make the random moves reproducible + env.action_space(agent).seed(seed) + + for agent in env.agent_iter(): + _, _, termination, truncation, _ = env.last() + + if termination or truncation: + env.step(None) + else: + action = env.action_space(agent).sample() + env.step(action) + + +def test_terminate_illegal() -> None: + """Test for a problem with terminate illegal wrapper. + + The problem is that env variables, including agent_selection, are set by + calls from TerminateIllegalWrapper to env functions. However, they are + called by the wrapper object, not the env so they are set in the wrapper + object rather than the base env object. When the code later tries to run, + the values get updated in the env code, but the wrapper pulls it's own + values that shadow them. + + The test here confirms that is fixed. + """ + # not using env() because we need to ensure that the env is + # wrapped by TerminateIllegalWrapper + raw_env = tictactoe_v3.raw_env() + env = TerminateIllegalWrapper(raw_env, illegal_reward=-1) + + _do_game(env, 42) + # bug is triggered by a corrupted state after a game is terminated + # due to an illegal move. So we need to run the game twice to + # see the effect. + _do_game(env, 42) + + # get a list of what all the agent_selection values in the wrapper stack + unwrapped = env + agent_selections = [] + while unwrapped != env.unwrapped: + # the actual value for this wrapper (or None if no value) + agent_selections.append(unwrapped.__dict__.get("agent_selection", None)) + assert isinstance(unwrapped, BaseWrapper) + unwrapped = unwrapped.env + + # last one from the actual env + agent_selections.append(unwrapped.__dict__.get("agent_selection", None)) + + # remove None from agent_selections + agent_selections = [x for x in agent_selections if x is not None] + + # all values must be the same, or else the wrapper and env are mismatched + assert len(set(agent_selections)) == 1, "agent_selection mismatch"