Skip to content

Commit

Permalink
ready for test
Browse files Browse the repository at this point in the history
  • Loading branch information
mttga committed Mar 15, 2024
1 parent a3638b1 commit bed7b60
Showing 1 changed file with 62 additions and 21 deletions.
83 changes: 62 additions & 21 deletions jaxmarl/environments/hanabi/manual_game.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
import jax
from jax import numpy as jnp
jax.config.update('jax_platform_name', 'cpu') # force playing on cpu
#jax.config.update('jax_platform_name', 'cpu') # force playing on cpu
from jaxmarl import make
import random
import pprint
import sys
import numpy as np
import argparse
from obl.obl_pytorch import OBLPytorchAgent
import json

OBL1A_WEIGHT = "obl/models/icml_OBL1/OFF_BELIEF1_SHUFFLE_COLOR0_BZA0_BELIEF_a/model0.pthw"
OBL1A_WEIGHT_TORCH = "/app/JaxMARL/jaxmarl/environments/hanabi/obl/models/torch_models/icml_OBL1/OFF_BELIEF1_SHUFFLE_COLOR0_BZA0_BELIEF_a/model0.pthw"
#OBL1A_WEIGHT_FLAX = "obl/models/flax_models/icml_OBL1/OFF_BELIEF1_SHUFFLE_COLOR0_BZA0_BELIEF_a.safetensors"
OBL1A_WEIGHT_FLAX = "/app/JaxMARL/jaxmarl/environments/hanabi/obl/models/flax_models/icml_OBL1/OFF_BELIEF1_SHUFFLE_COLOR0_BZA0_BELIEF_a.safetensors"

with open('decks_test.json') as f:
decks_j = json.load(f)

decks_test_rngs = jnp.array([jnp.array(np.array(deck['jax_rng'], dtype=np.uint32))for deck in decks_j])

with open('cpp_deck_actions.json', 'r') as file:
deck_actions = json.load(file)

class ManualPlayer:
def __init__(self, player_idx):
Expand Down Expand Up @@ -61,11 +70,18 @@ def get_agents(args):
weight_file = args.weight
else:
weight_file = getattr(args, f"weight{player_idx}")
if weight_file is None:
weight_file = OBL1A_WEIGHT

if player_type == "manual":
agents.append(ManualPlayer(player_idx))
elif player_type == 'obl_flax':
from obl.obl_flax import OBLFlaxAgent
if weight_file is None:
weight_file = OBL1A_WEIGHT_FLAX
agents.append(OBLFlaxAgent(weight_file, player_idx))
elif player_type == "obl":
from obl.obl_pytorch import OBLPytorchAgent
if weight_file is None:
weight_file = OBL1A_WEIGHT_TORCH
agents.append(OBLPytorchAgent(weight_file))

return agents
Expand All @@ -79,45 +95,69 @@ def play_game(args, action_encoding):
print(f"{'-'*10}\nStarting new game with random seed: {seed}\n")

agents = get_agents(args)

with jax.disable_jit():
env = make('hanabi', debug=False)

if args.use_jit is not None:
use_jit = args.use_jit
else:
use_jit = True

with jax.disable_jit(not use_jit):
env = make('hanabi')
rng = jax.random.PRNGKey(seed)
rng, _rng = jax.random.split(rng)
obs, state = env.reset(_rng)
legal_moves = env.get_legal_moves(state)

# custom seed from the deck test
_rng = decks_test_rngs[seed]

# custom actions
pre_actions = np.array(deck_actions[seed]).astype(int)

obs, env_state = env.reset(_rng)
legal_moves = env.get_legal_moves(env_state)

@jax.jit
def _step_env(rng, env_state, actions):
rng, _rng = jax.random.split(rng)
new_obs, new_env_state, reward, dones, infos = env.step(_rng, env_state, actions)
new_legal_moves = env.get_legal_moves(new_env_state)
return rng, new_env_state, new_obs, reward, dones, new_legal_moves

done = False
cum_rew = 0
t = 0

print("\n" + "=" * 40 + "\n")

while not done:
env.render(state)
print()
env.render(env_state)

curr_player = np.where(state.cur_player_idx==1)[0][0]
curr_player = np.where(env_state.cur_player_idx==1)[0][0]
actions_all = [
agents[i].act(env, obs, legal_moves, curr_player)
agents[i].act(obs, legal_moves, curr_player)
for i in range(len(env.agents))
]

actions = actions_all[curr_player]
#actions = np.zeros(2,dtype=int)
#actions[curr_player] = pre_actions[t]

played_action = (actions[curr_player] - 1) % 20
played_action = actions[curr_player]
print("played action:", played_action)
print(f"Move played: {action_encoding[played_action]} ({played_action})")

actions = {agent:jnp.array([actions[i]]) for i, agent in enumerate(env.agents)}
rng, _rng = jax.random.split(rng)
obs, state, reward, dones, infos = env.step(_rng, state, actions)
legal_moves = env.get_legal_moves(state)
actions = {agent:jnp.array(actions[i]) for i, agent in enumerate(env.agents)}

rng, env_state, obs, reward, dones, legal_moves = _step_env(rng, env_state, actions)

done = dones['__all__']
cum_rew += reward['__all__']
t += 1

print("\n" + "=" * 40 + "\n")



print('Game Ended.')
print('Game Ended. Score:', cum_rew)


def main(args):
Expand Down Expand Up @@ -151,7 +191,7 @@ def main(args):
play_game(args, action_encoding)

new_game = 'y'
while new_game=='y':
while False and new_game=='y':
new_game = input('New Game?')
if new_game=='y':
play_game(args, action_encoding)
Expand All @@ -165,6 +205,7 @@ def parse_args():
parser.add_argument("--weight0", type=str, default=None)
parser.add_argument("--weight1", type=str, default=None)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--use_jit", type=bool, default=True)
args = parser.parse_args()
print(args)
return args
Expand Down

0 comments on commit bed7b60

Please sign in to comment.