From 3f8f1bee8513581fa7d58f0e92a418b4b5c1532b Mon Sep 17 00:00:00 2001
From: David Ackerman <145808634+dm-ackerman@users.noreply.github.com>
Date: Wed, 14 Feb 2024 10:34:12 -0500
Subject: [PATCH] Correct obs type in parallel_rps doc example (#1170)

---
 docs/code_examples/parallel_rps.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/docs/code_examples/parallel_rps.py b/docs/code_examples/parallel_rps.py
index 383659666..bf634af10 100644
--- a/docs/code_examples/parallel_rps.py
+++ b/docs/code_examples/parallel_rps.py
@@ -1,6 +1,7 @@
 import functools
 
 import gymnasium
+import numpy as np
 from gymnasium.spaces import Discrete
 
 from pettingzoo import ParallelEnv
@@ -9,7 +10,7 @@
 ROCK = 0
 PAPER = 1
 SCISSORS = 2
-NONE = 3
+NO_MOVE = 3
 MOVES = ["ROCK", "PAPER", "SCISSORS", "None"]
 NUM_ITERS = 100
 REWARD_MAP = {
@@ -83,6 +84,7 @@ def __init__(self, render_mode=None):
     @functools.lru_cache(maxsize=None)
     def observation_space(self, agent):
         # gymnasium spaces are defined and documented here: https://gymnasium.farama.org/api/spaces/
+        # Discrete(4) means an integer in range(0, 4)
         return Discrete(4)
 
     # Action space should be defined here.
@@ -128,7 +130,8 @@ def reset(self, seed=None, options=None):
         """
         self.agents = self.possible_agents[:]
         self.num_moves = 0
-        observations = {agent: NONE for agent in self.agents}
+        # the observations should be numpy arrays even if there is only one value
+        observations = {agent: np.array(NO_MOVE) for agent in self.agents}
         infos = {agent: {} for agent in self.agents}
         self.state = observations
 
@@ -161,9 +164,11 @@ def step(self, actions):
         env_truncation = self.num_moves >= NUM_ITERS
         truncations = {agent: env_truncation for agent in self.agents}
 
-        # current observation is just the other player's most recent action
+        # Current observation is just the other player's most recent action
+        # This is converted to a numpy value of type int to match the type
+        # that we declared in observation_space()
         observations = {
-            self.agents[i]: int(actions[self.agents[1 - i]])
+            self.agents[i]: np.array(actions[self.agents[1 - i]], dtype=np.int64)
            for i in range(len(self.agents))
         }
         self.state = observations
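
A minimal sketch of the type issue the patch addresses, assuming gymnasium's
documented handling of integer scalars in Discrete spaces (the expected
outputs in the comments are illustrative, not part of the patch):

    import numpy as np
    from gymnasium.spaces import Discrete

    space = Discrete(4)  # the space the example declares in observation_space()

    # A zero-dimensional int64 array, as the patched example now produces.
    obs = np.array(3, dtype=np.int64)
    print(space.contains(obs))   # expected: True
    print(space.sample().dtype)  # expected: int64, matching the stored obs

    # A one-element array has shape (1,), not (), so it is not a valid
    # Discrete observation; hence the patch uses np.array(value) directly.
    print(space.contains(np.array([3])))  # expected: False

Running pettingzoo.test.parallel_api_test against the environment is one way
to surface mismatches between the declared spaces and the types the env
actually returns.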