Skip to content

Commit

Permalink
Merge pull request #60 from hy-kiera/fix_storm_done
Browse files Browse the repository at this point in the history
Fixing incomplete state update issue in the step function of STORM environemnt
  • Loading branch information
Aidandos authored Mar 13, 2024
2 parents 04470f1 + 051985a commit cc9f12b
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 37 deletions.
20 changes: 11 additions & 9 deletions jaxmarl/environments/storm/storm_2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,17 +751,19 @@ def _step(
inner_t = state_nxt.inner_t
outer_t = state_nxt.outer_t
reset_inner = inner_t == num_inner_steps
done = {}
done["__all__"] = reset_inner = inner_t == num_inner_steps

# if inner episode is done, return start state for next game
# state_re = _reset_state(key)
# state_re = state_re.replace(outer_t=outer_t + 1)
# state = jax.tree_map(
# lambda x, y: jax.lax.select(reset_inner, x, y),
# state_re,
# state_nxt,
# )
state_re = _reset_state(key)
state_re = state_re.replace(outer_t=outer_t + 1)
state = jax.tree_map(
lambda x, y: jax.lax.select(reset_inner, x, y),
state_re,
state_nxt,
)
outer_t = state.outer_t
reset_outer = outer_t == num_outer_steps
done = {}
done["__all__"] = reset_outer

obs = _get_obs(state)
blue_reward = jnp.where(reset_inner, 0, blue_reward)
Expand Down
22 changes: 12 additions & 10 deletions jaxmarl/environments/storm/storm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,18 +886,20 @@ def update_timers(coop_coin_timer, defect_coin_timer, new_coop_coin_timer, new_d
# now calculate if done for inner or outer episode
inner_t = state_nxt.inner_t
outer_t = state_nxt.outer_t
done = {}
done["__all__"] = reset_inner = inner_t == num_inner_steps

reset_inner = inner_t == num_inner_steps

# # # if inner episode is done, return start state for next game
# state_re = _reset_state(key)
# state_re = state_re.replace(outer_t=outer_t + 1)
# state = jax.tree_map(
# lambda x, y: jax.lax.select(reset_inner, x, y),
# state_re,
# state_nxt,
# )
state_re = _reset_state(key)
state_re = state_re.replace(outer_t=outer_t + 1)
state = jax.tree_map(
lambda x, y: jax.lax.select(reset_inner, x, y),
state_re,
state_nxt,
)
outer_t = state.outer_t
reset_outer = outer_t == num_outer_steps
done = {}
done["__all__"] = reset_outer

obs = _get_obs(state)
rewards = jnp.where(reset_inner, 0, rewards)
Expand Down
5 changes: 4 additions & 1 deletion jaxmarl/tutorials/storm_2p_introduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

action = 1
render_agent_view = True
num_outer_steps = 1
num_outer_steps = 3
num_inner_steps = 152

rng = jax.random.PRNGKey(0)
Expand Down Expand Up @@ -58,6 +58,9 @@
rng, old_state, (a1 * action, a2 * action)
)

print('outer t', state.outer_t)
print('inner t', state.inner_t)
print('done', done)
if (state.red_pos[:2] == state.blue_pos[:2]).all():
import pdb

Expand Down
21 changes: 4 additions & 17 deletions jaxmarl/tutorials/storm_introduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

action=1
render_agent_view = False
num_outer_steps = 1
num_outer_steps = 3
# num_inner_steps = 68
#num_agents=8
num_agents=2
Expand Down Expand Up @@ -96,22 +96,9 @@ def pos_add(i, val):
obs, state, reward, done, info = env.step_env(
rng, old_state, [a*action for a in actions]
)
# print(state.agent_inventories, 'agent inventories')
# print(actions.shape, 'actions')
# print(state.agent_positions.shape, 'agent positions')
# print(state.agent_freezes.shape, 'agent freezes')
# if (state.grid == old_state.grid).all():
# print(t, 'hello there')

# if (state.red_pos[:2] == state.blue_pos[:2]).all():
# import pdb

# # pdb.set_trace()
# print("collision")
# print(
# f"timestep: {t}, A1: {int_action[a1.item()]} A2:{int_action[a2.item()]}"
# )
# print(state.red_pos, state.blue_pos)
print('outer t', state.outer_t)
print('inner t', state.inner_t)
print('done', done)

img = env.render(state)
Image.fromarray(img).save(f"state_pics/state_{t+1}.png")
Expand Down

0 comments on commit cc9f12b

Please sign in to comment.