Skip to content

Commit

Permalink
Correct obs type in parallel_rps doc example (#1170)
Browse files Browse the repository at this point in the history
  • Loading branch information
dm-ackerman authored Feb 14, 2024
1 parent 6c8e8c1 commit 3f8f1be
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions docs/code_examples/parallel_rps.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools

import gymnasium
import numpy as np
from gymnasium.spaces import Discrete

from pettingzoo import ParallelEnv
Expand All @@ -9,7 +10,7 @@
ROCK = 0
PAPER = 1
SCISSORS = 2
NONE = 3
NO_MOVE = 3
MOVES = ["ROCK", "PAPER", "SCISSORS", "None"]
NUM_ITERS = 100
REWARD_MAP = {
Expand Down Expand Up @@ -83,6 +84,7 @@ def __init__(self, render_mode=None):
@functools.lru_cache(maxsize=None)
def observation_space(self, agent):
    """Return the observation space for *agent*.

    Every agent shares the same space: a single integer in ``range(0, 4)``
    (the opponent's last move, or NO_MOVE at the start of an episode).
    The ``lru_cache`` ensures repeated calls return the identical space
    object for a given agent.
    """
    # gymnasium spaces are defined and documented here: https://gymnasium.farama.org/api/spaces/
    # Discrete(4) means an integer in range(0, 4)
    return Discrete(4)

# Action space should be defined here.
Expand Down Expand Up @@ -128,7 +130,8 @@ def reset(self, seed=None, options=None):
"""
self.agents = self.possible_agents[:]
self.num_moves = 0
observations = {agent: NONE for agent in self.agents}
# the observations should be numpy arrays even if there is only one value
observations = {agent: np.array(NO_MOVE) for agent in self.agents}
infos = {agent: {} for agent in self.agents}
self.state = observations

Expand Down Expand Up @@ -161,9 +164,11 @@ def step(self, actions):
env_truncation = self.num_moves >= NUM_ITERS
truncations = {agent: env_truncation for agent in self.agents}

# current observation is just the other player's most recent action
# Current observation is just the other player's most recent action
# This is converted to a numpy value of type int to match the type
# that we declared in observation_space()
observations = {
self.agents[i]: int(actions[self.agents[1 - i]])
self.agents[i]: np.array(actions[self.agents[1 - i]], dtype=np.int64)
for i in range(len(self.agents))
}
self.state = observations
Expand Down

0 comments on commit 3f8f1be

Please sign in to comment.