From 7d7068d2c9e136466782b54f70b15e1065cc1776 Mon Sep 17 00:00:00 2001 From: David Ackerman <145808634+dm-ackerman@users.noreply.github.com> Date: Tue, 7 May 2024 13:43:18 +0000 Subject: [PATCH 1/6] Remove undefined value in TerminateIllegalWrapper --- pettingzoo/utils/wrappers/terminate_illegal.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pettingzoo/utils/wrappers/terminate_illegal.py b/pettingzoo/utils/wrappers/terminate_illegal.py index a49d9a0be..e0e9de0ac 100644 --- a/pettingzoo/utils/wrappers/terminate_illegal.py +++ b/pettingzoo/utils/wrappers/terminate_illegal.py @@ -20,6 +20,7 @@ def __init__( self._illegal_value = illegal_reward self._prev_obs = None self._prev_info = None + self._terminated = False # terminated by an illegal move def reset(self, seed: int | None = None, options: dict | None = None) -> None: self._terminated = False From 78ffc4a0260a7c0efc1eb2749089b25fde4be59a Mon Sep 17 00:00:00 2001 From: David Ackerman <145808634+dm-ackerman@users.noreply.github.com> Date: Tue, 7 May 2024 13:46:18 +0000 Subject: [PATCH 2/6] Remove redundant code in TerminateIllegalWrapper ``` if isinstance(self._prev_obs, dict): assert self._prev_obs is not None ``` `self._prev_obs` can't be `None` because it is a `dict` instance --- pettingzoo/utils/wrappers/terminate_illegal.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pettingzoo/utils/wrappers/terminate_illegal.py b/pettingzoo/utils/wrappers/terminate_illegal.py index e0e9de0ac..e8cec6a5e 100644 --- a/pettingzoo/utils/wrappers/terminate_illegal.py +++ b/pettingzoo/utils/wrappers/terminate_illegal.py @@ -43,7 +43,6 @@ def step(self, action: ActionType) -> None: if self._prev_obs is None: self.observe(self.agent_selection) if isinstance(self._prev_obs, dict): - assert self._prev_obs is not None assert ( "action_mask" in self._prev_obs ), f"`action_mask` not found in dictionary observation: {self._prev_obs}. 
Action mask must either be in `observation['action_mask']` or `info['action_mask']` to use TerminateIllegalWrapper." @@ -71,8 +70,6 @@ def step(self, action: ActionType) -> None: self.env.unwrapped._cumulative_rewards[self.agent_selection] = 0 self.env.unwrapped.terminations = {d: True for d in self.agents} self.env.unwrapped.truncations = {d: True for d in self.agents} - self._prev_obs = None - self._prev_info = None self.env.unwrapped.rewards = {d: 0 for d in self.truncations} self.env.unwrapped.rewards[current_agent] = float(self._illegal_value) self._accumulate_rewards() From 5a2ea12797abe5d5f3170325e5e1b7e5e58153cc Mon Sep 17 00:00:00 2001 From: David Ackerman <145808634+dm-ackerman@users.noreply.github.com> Date: Tue, 7 May 2024 13:54:44 +0000 Subject: [PATCH 3/6] Remove pyright flags in TerminateIllegalWrapper works fine without them --- pettingzoo/utils/wrappers/terminate_illegal.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pettingzoo/utils/wrappers/terminate_illegal.py b/pettingzoo/utils/wrappers/terminate_illegal.py index e8cec6a5e..550ac95ec 100644 --- a/pettingzoo/utils/wrappers/terminate_illegal.py +++ b/pettingzoo/utils/wrappers/terminate_illegal.py @@ -1,4 +1,3 @@ -# pyright reportGeneralTypeIssues=false from __future__ import annotations from pettingzoo.utils.env import ActionType, AECEnv, AgentID, ObsType @@ -60,7 +59,7 @@ def step(self, action: ActionType) -> None: self.terminations[self.agent_selection] or self.truncations[self.agent_selection] ): - self._was_dead_step(action) # pyright: ignore[reportGeneralTypeIssues] + self._was_dead_step(action) elif ( not self.terminations[self.agent_selection] and not self.truncations[self.agent_selection] From 2d3486b8d32de910ae280eae0271a62df21f36c8 Mon Sep 17 00:00:00 2001 From: David Ackerman <145808634+dm-ackerman@users.noreply.github.com> Date: Tue, 7 May 2024 15:42:53 +0000 Subject: [PATCH 4/6] Fix value setting bug in TerminateIllegalWrapper Functions need to be set from the 
actual env, not the wrapper. --- pettingzoo/utils/wrappers/terminate_illegal.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pettingzoo/utils/wrappers/terminate_illegal.py b/pettingzoo/utils/wrappers/terminate_illegal.py index 550ac95ec..79f95504a 100644 --- a/pettingzoo/utils/wrappers/terminate_illegal.py +++ b/pettingzoo/utils/wrappers/terminate_illegal.py @@ -59,7 +59,7 @@ def step(self, action: ActionType) -> None: self.terminations[self.agent_selection] or self.truncations[self.agent_selection] ): - self._was_dead_step(action) + self.env.unwrapped._was_dead_step(action) elif ( not self.terminations[self.agent_selection] and not self.truncations[self.agent_selection] @@ -71,8 +71,8 @@ def step(self, action: ActionType) -> None: self.env.unwrapped.truncations = {d: True for d in self.agents} self.env.unwrapped.rewards = {d: 0 for d in self.truncations} self.env.unwrapped.rewards[current_agent] = float(self._illegal_value) - self._accumulate_rewards() - self._deads_step_first() + self.env.unwrapped._accumulate_rewards() + self.env.unwrapped._deads_step_first() self._terminated = True else: super().step(action) From 63543af211e3c43ad1959e38dbfdc94271800eec Mon Sep 17 00:00:00 2001 From: David Ackerman <145808634+dm-ackerman@users.noreply.github.com> Date: Tue, 7 May 2024 15:55:14 +0000 Subject: [PATCH 5/6] Add test for TerminateIllegalWrapper bug --- test/terminite_illegal_bug_test.py | 71 ++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 test/terminite_illegal_bug_test.py diff --git a/test/terminite_illegal_bug_test.py b/test/terminite_illegal_bug_test.py new file mode 100644 index 000000000..be8f28c8e --- /dev/null +++ b/test/terminite_illegal_bug_test.py @@ -0,0 +1,71 @@ +"""Check for a problem with terminate illegal wrapper. + +The problem is that env variables, including agent_selection, are set by +calls from TerminateIllegalWrapper to env functions. 
However, they are +called by the wrapper object, not the env so they are set in the wrapper +object rather than the base env object. When the code later tries to run, +the values get updated in the env code, but the wrapper pulls it's own +values that shadow them. + +The test here confirms that is fixed. +""" + +from pettingzoo.classic import tictactoe_v3 +from pettingzoo.utils.wrappers import BaseWrapper, TerminateIllegalWrapper + + +def _do_game(env: TerminateIllegalWrapper, seed: int) -> None: + """Run a single game with reproducible random moves.""" + assert isinstance( + env, TerminateIllegalWrapper + ), "test_terminate_illegal must use TerminateIllegalWrapper" + env.reset(seed) + for agent in env.agents: + # make the random moves reproducible + env.action_space(agent).seed(seed) + + for agent in env.agent_iter(): + _, _, termination, truncation, _ = env.last() + + if termination or truncation: + env.step(None) + else: + action = env.action_space(agent).sample() + env.step(action) + + +def test_terminate_illegal() -> None: + """Test for error in TerminateIllegalWrapper. + + A bug caused TerminateIllegalWrapper to set values on the wrapper + rather than the environment. This tests for a recurrence of that + bug. + """ + # not using env() because we need to ensure that the env is + # wrapped by TerminateIllegalWrapper + raw_env = tictactoe_v3.raw_env() + env = TerminateIllegalWrapper(raw_env, illegal_reward=-1) + + _do_game(env, 42) + # bug is triggered by a corrupted state after a game is terminated + # due to an illegal move. So we need to run the game twice to + # see the effect. 
+ _do_game(env, 42) + + # get a list of what all the agent_selection values in the wrapper stack + unwrapped = env + agent_selections = [] + while unwrapped != env.unwrapped: + # the actual value for this wrapper (or None if no value) + agent_selections.append(unwrapped.__dict__.get("agent_selection", None)) + assert isinstance(unwrapped, BaseWrapper) + unwrapped = unwrapped.env + + # last one from the actual env + agent_selections.append(unwrapped.__dict__.get("agent_selection", None)) + + # remove None from agent_selections + agent_selections = [x for x in agent_selections if x is not None] + + # all values must be the same, or else the wrapper and env are mismatched + assert len(set(agent_selections)) == 1, "agent_selection mismatch" From e891f00d2b507ca9a42bc7fc81b4d0d004a51391 Mon Sep 17 00:00:00 2001 From: David Ackerman <145808634+dm-ackerman@users.noreply.github.com> Date: Tue, 7 May 2024 21:41:24 +0000 Subject: [PATCH 6/6] Move wrapper test to file with other tests --- test/terminite_illegal_bug_test.py | 71 ------------------------------ test/wrapper_test.py | 71 +++++++++++++++++++++++++++++- 2 files changed, 69 insertions(+), 73 deletions(-) delete mode 100644 test/terminite_illegal_bug_test.py diff --git a/test/terminite_illegal_bug_test.py b/test/terminite_illegal_bug_test.py deleted file mode 100644 index be8f28c8e..000000000 --- a/test/terminite_illegal_bug_test.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Check for a problem with terminate illegal wrapper. - -The problem is that env variables, including agent_selection, are set by -calls from TerminateIllegalWrapper to env functions. However, they are -called by the wrapper object, not the env so they are set in the wrapper -object rather than the base env object. When the code later tries to run, -the values get updated in the env code, but the wrapper pulls it's own -values that shadow them. - -The test here confirms that is fixed. 
-""" - -from pettingzoo.classic import tictactoe_v3 -from pettingzoo.utils.wrappers import BaseWrapper, TerminateIllegalWrapper - - -def _do_game(env: TerminateIllegalWrapper, seed: int) -> None: - """Run a single game with reproducible random moves.""" - assert isinstance( - env, TerminateIllegalWrapper - ), "test_terminate_illegal must use TerminateIllegalWrapper" - env.reset(seed) - for agent in env.agents: - # make the random moves reproducible - env.action_space(agent).seed(seed) - - for agent in env.agent_iter(): - _, _, termination, truncation, _ = env.last() - - if termination or truncation: - env.step(None) - else: - action = env.action_space(agent).sample() - env.step(action) - - -def test_terminate_illegal() -> None: - """Test for error in TerminateIllegalWrapper. - - A bug caused TerminateIllegalWrapper to set values on the wrapper - rather than the environment. This tests for a recurrence of that - bug. - """ - # not using env() because we need to ensure that the env is - # wrapped by TerminateIllegalWrapper - raw_env = tictactoe_v3.raw_env() - env = TerminateIllegalWrapper(raw_env, illegal_reward=-1) - - _do_game(env, 42) - # bug is triggered by a corrupted state after a game is terminated - # due to an illegal move. So we need to run the game twice to - # see the effect. 
- _do_game(env, 42) - - # get a list of what all the agent_selection values in the wrapper stack - unwrapped = env - agent_selections = [] - while unwrapped != env.unwrapped: - # the actual value for this wrapper (or None if no value) - agent_selections.append(unwrapped.__dict__.get("agent_selection", None)) - assert isinstance(unwrapped, BaseWrapper) - unwrapped = unwrapped.env - - # last one from the actual env - agent_selections.append(unwrapped.__dict__.get("agent_selection", None)) - - # remove None from agent_selections - agent_selections = [x for x in agent_selections if x is not None] - - # all values must be the same, or else the wrapper and env are mismatched - assert len(set(agent_selections)) == 1, "agent_selection mismatch" diff --git a/test/wrapper_test.py b/test/wrapper_test.py index 650fe328b..a03bd81b3 100644 --- a/test/wrapper_test.py +++ b/test/wrapper_test.py @@ -3,8 +3,13 @@ import pytest from pettingzoo.butterfly import pistonball_v6 -from pettingzoo.classic import texas_holdem_no_limit_v6 -from pettingzoo.utils.wrappers import MultiEpisodeEnv, MultiEpisodeParallelEnv +from pettingzoo.classic import texas_holdem_no_limit_v6, tictactoe_v3 +from pettingzoo.utils.wrappers import ( + BaseWrapper, + MultiEpisodeEnv, + MultiEpisodeParallelEnv, + TerminateIllegalWrapper, +) @pytest.mark.parametrize(("num_episodes"), [1, 2, 3, 4, 5, 6]) @@ -67,3 +72,65 @@ def test_multi_episode_parallel_env_wrapper(num_episodes) -> None: assert ( steps == num_episodes * 125 ), f"Expected to have 125 steps per episode, got {steps / num_episodes}." 
+ + +def _do_game(env: TerminateIllegalWrapper, seed: int) -> None: + """Run a single game with reproducible random moves.""" + assert isinstance( + env, TerminateIllegalWrapper + ), "test_terminate_illegal must use TerminateIllegalWrapper" + env.reset(seed) + for agent in env.agents: + # make the random moves reproducible + env.action_space(agent).seed(seed) + + for agent in env.agent_iter(): + _, _, termination, truncation, _ = env.last() + + if termination or truncation: + env.step(None) + else: + action = env.action_space(agent).sample() + env.step(action) + + +def test_terminate_illegal() -> None: + """Test for a problem with terminate illegal wrapper. + + The problem is that env variables, including agent_selection, are set by + calls from TerminateIllegalWrapper to env functions. However, they are + called by the wrapper object, not the env so they are set in the wrapper + object rather than the base env object. When the code later tries to run, + the values get updated in the env code, but the wrapper pulls its own + values that shadow them. + + The test here confirms that is fixed. + """ + # not using env() because we need to ensure that the env is + # wrapped by TerminateIllegalWrapper + raw_env = tictactoe_v3.raw_env() + env = TerminateIllegalWrapper(raw_env, illegal_reward=-1) + + _do_game(env, 42) + # bug is triggered by a corrupted state after a game is terminated + # due to an illegal move. So we need to run the game twice to + # see the effect. 
+ _do_game(env, 42) + + # get a list of all the agent_selection values in the wrapper stack + unwrapped = env + agent_selections = [] + while unwrapped != env.unwrapped: + # the actual value for this wrapper (or None if no value) + agent_selections.append(unwrapped.__dict__.get("agent_selection", None)) + assert isinstance(unwrapped, BaseWrapper) + unwrapped = unwrapped.env + + # last one from the actual env + agent_selections.append(unwrapped.__dict__.get("agent_selection", None)) + + # remove None from agent_selections + agent_selections = [x for x in agent_selections if x is not None] + + # all values must be the same, or else the wrapper and env are mismatched + assert len(set(agent_selections)) == 1, "agent_selection mismatch"