Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TerminateIllegalWrapper fix #1206

Merged
merged 6 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions pettingzoo/utils/wrappers/terminate_illegal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# pyright reportGeneralTypeIssues=false
from __future__ import annotations

from pettingzoo.utils.env import ActionType, AECEnv, AgentID, ObsType
Expand All @@ -20,6 +19,7 @@ def __init__(
self._illegal_value = illegal_reward
self._prev_obs = None
self._prev_info = None
self._terminated = False # terminated by an illegal move

def reset(self, seed: int | None = None, options: dict | None = None) -> None:
self._terminated = False
Expand All @@ -42,7 +42,6 @@ def step(self, action: ActionType) -> None:
if self._prev_obs is None:
self.observe(self.agent_selection)
if isinstance(self._prev_obs, dict):
assert self._prev_obs is not None
assert (
jjshoots marked this conversation as resolved.
Show resolved Hide resolved
"action_mask" in self._prev_obs
), f"`action_mask` not found in dictionary observation: {self._prev_obs}. Action mask must either be in `observation['action_mask']` or `info['action_mask']` to use TerminateIllegalWrapper."
Expand All @@ -60,7 +59,7 @@ def step(self, action: ActionType) -> None:
self.terminations[self.agent_selection]
or self.truncations[self.agent_selection]
):
self._was_dead_step(action) # pyright: ignore[reportGeneralTypeIssues]
self.env.unwrapped._was_dead_step(action)
elif (
not self.terminations[self.agent_selection]
and not self.truncations[self.agent_selection]
Expand All @@ -70,12 +69,10 @@ def step(self, action: ActionType) -> None:
self.env.unwrapped._cumulative_rewards[self.agent_selection] = 0
self.env.unwrapped.terminations = {d: True for d in self.agents}
self.env.unwrapped.truncations = {d: True for d in self.agents}
self._prev_obs = None
self._prev_info = None
self.env.unwrapped.rewards = {d: 0 for d in self.truncations}
self.env.unwrapped.rewards[current_agent] = float(self._illegal_value)
self._accumulate_rewards()
self._deads_step_first()
self.env.unwrapped._accumulate_rewards()
self.env.unwrapped._deads_step_first()
self._terminated = True
else:
super().step(action)
Expand Down
71 changes: 69 additions & 2 deletions test/wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@
import pytest

from pettingzoo.butterfly import pistonball_v6
from pettingzoo.classic import texas_holdem_no_limit_v6
from pettingzoo.utils.wrappers import MultiEpisodeEnv, MultiEpisodeParallelEnv
from pettingzoo.classic import texas_holdem_no_limit_v6, tictactoe_v3
from pettingzoo.utils.wrappers import (
BaseWrapper,
MultiEpisodeEnv,
MultiEpisodeParallelEnv,
TerminateIllegalWrapper,
)


@pytest.mark.parametrize(("num_episodes"), [1, 2, 3, 4, 5, 6])
Expand Down Expand Up @@ -67,3 +72,65 @@ def test_multi_episode_parallel_env_wrapper(num_episodes) -> None:
assert (
steps == num_episodes * 125
), f"Expected to have 125 steps per episode, got {steps / num_episodes}."


def _do_game(env: TerminateIllegalWrapper, seed: int) -> None:
    """Play one complete game on ``env`` using seeded random actions.

    The environment is reset with ``seed`` and each agent's action space is
    seeded with the same value, so the sampled moves repeat from run to run.

    Args:
        env: The environment to play; must be a TerminateIllegalWrapper.
        seed: Seed passed to ``env.reset`` and to every agent's action space.
    """
    assert isinstance(
        env, TerminateIllegalWrapper
    ), "test_terminate_illegal must use TerminateIllegalWrapper"
    env.reset(seed)

    # Seed every action space so that sample() below is reproducible.
    for name in env.agents:
        env.action_space(name).seed(seed)

    for current in env.agent_iter():
        _obs, _reward, terminated, truncated, _info = env.last()
        finished = terminated or truncated
        # Terminated or truncated agents are stepped with None; only live
        # agents sample an action from their action space.
        move = None if finished else env.action_space(current).sample()
        env.step(move)


def test_terminate_illegal() -> None:
    """Check that TerminateIllegalWrapper stays in sync with its base env.

    The problem being guarded against: env variables, including
    agent_selection, are set by calls from TerminateIllegalWrapper to env
    functions. Because those functions were called on the wrapper object
    rather than on the env, the variables ended up stored on the wrapper
    instead of on the base env. When the code later runs, the env updates
    its values, but the wrapper reads its own copies, which shadow them.

    This test confirms that every layer of the wrapper stack that stores an
    agent_selection agrees on its value.
    """
    # Build the stack by hand rather than via env() so that the outermost
    # object is guaranteed to be a TerminateIllegalWrapper.
    env = TerminateIllegalWrapper(tictactoe_v3.raw_env(), illegal_reward=-1)

    # The corrupted state only shows up after a game has been terminated by
    # an illegal move, so the game has to be played twice to see the effect.
    for _ in range(2):
        _do_game(env, 42)

    # Walk down the wrapper stack, recording the agent_selection stored
    # directly on each layer (None when a layer holds no value of its own).
    layer = env
    recorded = []
    while layer != env.unwrapped:
        recorded.append(layer.__dict__.get("agent_selection", None))
        assert isinstance(layer, BaseWrapper)
        layer = layer.env

    # Finally, the value held by the actual base env.
    recorded.append(layer.__dict__.get("agent_selection", None))

    # Ignore layers with no value; the remaining values must all be the
    # same, or else the wrapper and env are mismatched.
    distinct = {value for value in recorded if value is not None}
    assert len(distinct) == 1, "agent_selection mismatch"
Loading