Skip to content

Commit

Permalink
agent id wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
amacrutherford committed Mar 21, 2024
1 parent 351aed4 commit 6ec2e1a
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 7 deletions.
20 changes: 13 additions & 7 deletions jaxmarl/environments/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,21 @@
import jax.numpy as jnp

class Space(object):
    """
    Abstract base for jaxmarl spaces, kept minimal so it stays jittable.
    """

    def sample(self, rng: chex.PRNGKey) -> chex.Array:
        """Draw an element of the space using `rng`; not implemented on the base class."""
        raise NotImplementedError

    def contains(self, x: jnp.int_) -> bool:
        """Report whether `x` lies in the space; not implemented on the base class."""
        raise NotImplementedError

    def replace(self, **kwargs):
        """Overwrite the named attributes on this space and return it.

        Note that the update happens in place: the returned object is
        `self`, not a copy.
        """
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])
        return self

class Discrete(Space):
"""
Expand Down
31 changes: 31 additions & 0 deletions jaxmarl/wrappers/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,37 @@ def _batchify_floats(self, x: dict):
return jnp.stack([x[a] for a in self._env.agents])


class AddAgentID(JaxMARLWrapper):
    """Append a one-hot encoded agent id to every agent's observation.

    The id is concatenated onto the *end* of each observation vector
    (``[obs, one_hot_id]``) and the wrapped environment's observation
    spaces are widened by ``num_agents`` entries to match.
    """

    def __init__(self, env: MultiAgentEnv):
        """Wrap `env` and widen its observation spaces by ``num_agents``."""
        import copy  # local import: only needed while building the spaces

        super().__init__(env)
        num_agents = self._env.num_agents
        # Row i is the one-hot id of the i-th entry of ``self._env.agents``.
        self.agent_ids = jnp.eye(num_agents)

        # Update observation space.
        # ``Space.replace`` sets attributes in place, so work on a shallow
        # copy of each space: otherwise a space object shared by several
        # agents would be widened once per agent, and the caller's original
        # space objects would be silently modified.
        # Assumes each space is a pytree leaf with a 1-D ``shape`` -- same
        # assumption as before this change.
        self._env.observation_spaces = jax.tree_util.tree_map(
            lambda s: copy.copy(s).replace(shape=(s.shape[0] + num_agents,)),
            self._env.observation_spaces,
        )

    def _append_agent_ids(self, obs: dict) -> dict:
        """Return a per-agent dict of observations with the one-hot id appended.

        Observations are stacked via ``self._batchify_floats`` (one row per
        agent, in ``self._env.agents`` order), the identity matrix is
        concatenated along the last axis, and the rows are handed back out
        under the same agent keys.
        """
        obs_batch = self._batchify_floats(obs)
        with_ids = jnp.concatenate([obs_batch, self.agent_ids], axis=-1)
        return {a: with_ids[i] for i, a in enumerate(self._env.agents)}

    def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, State]:
        """Reset the wrapped env and append agent ids to the observations."""
        obs, state = self._env.reset(key)
        return self._append_agent_ids(obs), state

    def step(
        self,
        key: chex.PRNGKey,
        state: State,
        action: Union[int, float],
    ) -> Tuple[chex.Array, State, float, bool, dict]:
        """Step the wrapped env and append agent ids to the observations.

        Reward, done and info are passed through from the wrapped env
        unchanged.
        """
        obs, state, reward, done, info = self._env.step(key, state, action)
        return self._append_agent_ids(obs), state, reward, done, info


@struct.dataclass
class LogEnvState:
env_state: State
Expand Down

0 comments on commit 6ec2e1a

Please sign in to comment.