
Merge pull request #69 from FLAIROx/smax-remove-unit-dependence
Add different observation and action space
amacrutherford authored Mar 20, 2024
2 parents cc9f12b + 32a650c commit 60b9fd2
Showing 4 changed files with 747 additions and 173 deletions.
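The commit message above refers to the new observation and action space handling visible in the diff below (get_obs_unit_list for observations, an action_type == "continuous" branch for actions). A hedged usage sketch follows; the registry name "HeuristicEnemySMAX", the scenario helper, and the action_type keyword are assumptions based on the public JaxMARL examples and the attribute checked in the diff, not confirmed by this page.

    # Hedged sketch: constructing the env with the new action space option.
    # The keyword name action_type is an assumption inferred from the
    # attribute checked in step_env (self._env.action_type == "continuous").
    from jaxmarl import make
    from jaxmarl.environments.smax import map_name_to_scenario

    # discrete actions (default behaviour)
    env_discrete = make("HeuristicEnemySMAX", scenario=map_name_to_scenario("3m"))

    # continuous actions, if the constructor exposes the attribute used in the diff
    env_continuous = make(
        "HeuristicEnemySMAX",
        scenario=map_name_to_scenario("3m"),
        action_type="continuous",
    )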
104 changes: 71 additions & 33 deletions jaxmarl/environments/smax/heuristic_enemy_smax_env.py
@@ -1,4 +1,6 @@
+import dataclasses
 from jaxmarl.environments.smax.smax_env import SMAX
+from jaxmarl.environments.smax.smax_env import State as SMAXState
 from jaxmarl.environments.smax.heuristic_enemy import (
     create_heuristic_policy,
     get_heuristic_policy_initial_state,
@@ -59,30 +61,74 @@ def get_enemy_actions(self, key, enemy_policy_state, enemy_obs):
     def get_enemy_policy_initial_state(self, key):
         raise NotImplementedError

-    @partial(jax.jit, static_argnums=(0,))
-    def step_env(self, key: chex.PRNGKey, state: State, actions: Dict[str, chex.Array]):
+    @partial(jax.jit, static_argnums=(0, 4))
+    def step_env(
+        self,
+        key: chex.PRNGKey,
+        state: State,
+        actions: Dict[str, chex.Array],
+        get_state_sequence=False,
+    ):
         jaxmarl_state = state.state
-        obs = self._env.get_obs(jaxmarl_state)
-        enemy_obs = jnp.array([obs[agent] for agent in self.enemy_agents])
+        enemy_obs = self._env.get_obs_unit_list(jaxmarl_state)
+        enemy_obs = jnp.array([enemy_obs[agent] for agent in self.enemy_agents])
         key, action_key = jax.random.split(key)
         enemy_actions, enemy_policy_state = self.get_enemy_actions(
             action_key, state.enemy_policy_state, enemy_obs
         )
+        enemy_actions = jnp.array([enemy_actions[i] for i in self.enemy_agents])
+        actions = jnp.array([actions[i] for i in self.agents])
+        enemy_movement_actions, enemy_attack_actions = (
+            self._env._decode_discrete_actions(enemy_actions)
+        )
+        if self._env.action_type == "continuous":
+            cont_actions = jnp.zeros((len(self.all_agents), 4))
+            cont_actions = cont_actions.at[: self.num_allies].set(actions)
+            key, action_key = jax.random.split(key)
+            ally_movement_actions, ally_attack_actions = (
+                self._env._decode_continuous_actions(
+                    action_key, jaxmarl_state, cont_actions
+                )
+            )
+            ally_movement_actions = ally_movement_actions[: self.num_allies]
+            ally_attack_actions = ally_attack_actions[: self.num_allies]
+        else:
+            ally_movement_actions, ally_attack_actions = (
+                self._env._decode_discrete_actions(actions)
+            )

-        actions = {k: v.squeeze() for k, v in actions.items()}
-        actions = {**enemy_actions, **actions}
-        obs, jaxmarl_state, rewards, dones, infos = self._env.step_env(
-            key, jaxmarl_state, actions
+        movement_actions = jnp.concat(
+            [ally_movement_actions, enemy_movement_actions], axis=0
         )
-        new_obs = {agent: obs[agent] for agent in self.agents}
-        new_obs["world_state"] = obs["world_state"]
-        rewards = {agent: rewards[agent] for agent in self.agents}
-        all_done = dones["__all__"]
-        dones = {agent: dones[agent] for agent in self.agents}
-        dones["__all__"] = all_done
+        attack_actions = jnp.concat([ally_attack_actions, enemy_attack_actions], axis=0)

+        if not get_state_sequence:
+            obs, jaxmarl_state, rewards, dones, infos = self._env.step_env_no_decode(
+                key,
+                jaxmarl_state,
+                (movement_actions, attack_actions),
+                get_state_sequence=get_state_sequence,
+            )
+            new_obs = {agent: obs[agent] for agent in self.agents}
+            new_obs["world_state"] = obs["world_state"]
+            rewards = {agent: rewards[agent] for agent in self.agents}
+            all_done = dones["__all__"]
+            dones = {agent: dones[agent] for agent in self.agents}
+            dones["__all__"] = all_done

-        state = state.replace(enemy_policy_state=enemy_policy_state, state=jaxmarl_state)
-        return new_obs, state, rewards, dones, infos
+            state = state.replace(
+                enemy_policy_state=enemy_policy_state, state=jaxmarl_state
+            )
+            return new_obs, state, rewards, dones, infos
+        else:
+            states = self._env.step_env_no_decode(
+                key,
+                jaxmarl_state,
+                (movement_actions, attack_actions),
+                get_state_sequence=get_state_sequence,
+            )
+            return states

     @partial(jax.jit, static_argnums=(0,))
     def get_avail_actions(self, state: State):
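An aside on the new static_argnums=(0, 4): get_state_sequence is the fifth positional argument of step_env (after self, key, state, actions), and it selects between two different return structures, so it must be treated as a static value under jax.jit rather than traced. A minimal, self-contained sketch of the same pattern, written for this note rather than taken from the repository:

    # Why a flag like get_state_sequence goes in static_argnums: it drives
    # Python-level control flow and changes the structure of the returned
    # value, so it cannot be a traced array; marking it static makes JAX
    # compile one specialization per flag value.
    from functools import partial

    import jax
    import jax.numpy as jnp


    @partial(jax.jit, static_argnums=(1,))
    def step(x, return_sequence):
        y = x + 1.0
        if return_sequence:
            # sequence path: return the stacked trajectory instead of one state
            return jnp.stack([x, y])
        return y


    print(step(jnp.zeros(3), False).shape)  # (3,)   single next state
    print(step(jnp.zeros(3), True).shape)   # (2, 3) stacked state sequence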
@@ -110,28 +156,20 @@ def expand_state_seq(self, state_seq):
         # it's not exposed to the user so we can't ask them to store it. Not a problem
         # for now but will have to get creative in the future potentially.
         for key, state, actions in state_seq:
-            agents = self.all_agents
             # There is a split in the step function of MultiAgentEnv
             # We call split here so that the action key is the same.
             key, _ = jax.random.split(key)
-            key, key_action = jax.random.split(key)
-            obs = self.get_all_unit_obs(state)
-            obs = jnp.array([obs[agent] for agent in self.enemy_agents])
-            enemy_actions, _ = self.get_enemy_actions(
-                key_action, state.enemy_policy_state, obs
-            )
-            actions = {k: v.squeeze() for k, v in actions.items()}
-            actions = {**enemy_actions, **actions}
-            for _ in range(self.world_steps_per_env_step):
-                expanded_state_seq.append((key, state.state, actions))
-                world_actions = jnp.array([actions[i] for i in agents])
-                key, step_key = jax.random.split(key)
-                _state = state.state
-                _state = self._env._world_step(step_key, _state, world_actions)
-                _state = self._env._kill_agents_touching_walls(_state)
-                _state = self._env._update_dead_agents(_state)
-                _state = self._env._push_units_away(_state)
-                state = state.replace(state=_state)
+            states = self.step_env(key, state, actions, get_state_sequence=True)
+            states = list(map(SMAXState, *dataclasses.astuple(states)))
+            viz_actions = {
+                agent: states[-1].prev_attack_actions[i]
+                for i, agent in enumerate(self.all_agents)
+            }
+
+            expanded_state_seq.append((key, state.state, viz_actions))
+            expanded_state_seq.extend(
+                zip([key] * len(states), states, [viz_actions] * len(states))
+            )
             state = state.replace(
                 state=state.state.replace(terminal=self.is_terminal(state))
             )
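For context, a hedged sketch of how the reworked expand_state_seq might be driven when rendering a rollout; the environment name, the scenario helper, and the random-action sampling are assumptions drawn from the public JaxMARL examples rather than from this diff.

    # Collect (key, state, actions) tuples during a short rollout, then let
    # expand_state_seq replay each env step as its underlying world steps.
    import jax
    from jaxmarl import make
    from jaxmarl.environments.smax import map_name_to_scenario

    env = make("HeuristicEnemySMAX", scenario=map_name_to_scenario("3m"))
    key = jax.random.PRNGKey(0)
    key, key_reset = jax.random.split(key)
    obs, state = env.reset(key_reset)

    state_seq = []
    for _ in range(5):
        key, key_act, key_step = jax.random.split(key, 3)
        # actions only for the learning (ally) agents; the heuristic enemy acts internally
        actions = {
            agent: env.action_space(agent).sample(key_act) for agent in env.agents
        }
        # store the key fed to step so expand_state_seq can replay the transition
        state_seq.append((key_step, state, actions))
        obs, state, rewards, dones, infos = env.step(key_step, state, actions)

    # each env step is expanded into its world steps, with per-agent attack
    # actions recovered from prev_attack_actions for visualisation
    expanded = env.expand_state_seq(state_seq)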
