Skip to content

Commit

Permalink
Merge pull request #63 from FLAIROx/hanabi_obl_pytorch
Browse files Browse the repository at this point in the history
Hanabi obl pytorch
  • Loading branch information
mttga authored Mar 5, 2024
2 parents a0e312c + c027564 commit aae9f99
Show file tree
Hide file tree
Showing 8 changed files with 1,337 additions and 54 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ docs/
tmp/
*-checkpoint.py
wandb/
outputs/
outputs/
models/
3 changes: 3 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ RUN pip install --ignore-installed -e '.[qlearning, dev]'
# install jax from to enable cuda
RUN pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

RUN pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118


#disabling preallocation
RUN export XLA_PYTHON_CLIENT_PREALLOCATE=false
#safety measures
Expand Down
4 changes: 2 additions & 2 deletions jaxmarl/environments/hanabi/hanabi.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def _hint_fn(state, action):
game_won = (fireworks_after == (self.num_colors * self.num_ranks))
deck_empty = (state.num_cards_dealt >= self.deck_size)
last_round_count = state.last_round_count + deck_empty
last_round_done = (last_round_count == self.num_agents)
last_round_done = (last_round_count == self.num_agents + 1)
terminal = jnp.logical_or(jnp.logical_or(state.out_of_lives, game_won), last_round_done)

# last moves
Expand Down Expand Up @@ -585,4 +585,4 @@ def observation_space(self, agent: str):

def action_space(self, agent: str):
""" Action space for a given agent."""
return self.action_spaces[agent]
return self.action_spaces[agent]
6 changes: 0 additions & 6 deletions jaxmarl/environments/hanabi/hanabi_obl.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def __init__(
num_moves=None,
debug=False
):

super().__init__(
num_agents=num_agents,
num_colors=num_colors,
Expand Down Expand Up @@ -439,9 +438,4 @@ def get_card_knowledge_str(card_idx:int)->str:
print(f'Actor {aidx} Hand:' + ('<-- current player' if aidx==current_player else ''))
for card_str in get_actor_hand_str(aidx):
print(card_str)

print('---')




181 changes: 136 additions & 45 deletions jaxmarl/environments/hanabi/manual_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,83 +5,174 @@
import random
import pprint
import sys
import numpy as np
import argparse
from obl.obl_pytorch import OBLPytorchAgent

def play_game():
OBL1A_WEIGHT = "obl/models/icml_OBL1/OFF_BELIEF1_SHUFFLE_COLOR0_BZA0_BELIEF_a/model0.pthw"

seed = random.randint(0, 10000)

class ManualPlayer:
def __init__(self, player_idx):
self._player_idx = player_idx

def act(self, env, obs, legal_moves, curr_player) -> int:
legal_moves = batchify(env, legal_moves)
legal_moves = jnp.roll(legal_moves, -1, axis=1)

actions = np.array([0, 0])

if curr_player != self._player_idx:
return actions

print("Legal moves:")
print(legal_moves[curr_player])

# take action input from user
while True:
try:
print("---")
action = int(input('Insert manual action: '))
print("action legal:", legal_moves[curr_player][action])
print("---\n")
if action >= 0 and action <= 20 and legal_moves[curr_player][action] == 1:
break
else:
print('Invalid action.')
except KeyboardInterrupt:
sys.exit(0)
except:
action = 0
print('Invalid action.')

action = (action + 1) % 21

actions[curr_player] = action

return actions


def get_agents(args):
agents = []

for player_idx in [0, 1]:
player_type = getattr(args, f"player{player_idx}")
if args.weight is not None:
weight_file = args.weight
else:
weight_file = getattr(args, f"weight{player_idx}")
if weight_file is None:
weight_file = OBL1A_WEIGHT
if player_type == "manual":
agents.append(ManualPlayer(player_idx))
elif player_type == "obl":
agents.append(OBLPytorchAgent(weight_file))

return agents


def play_game(args, action_encoding):
if args.seed is not None:
seed = args.seed
else:
seed = random.randint(0, 10000)
print(f"{'-'*10}\nStarting new game with random seed: {seed}\n")

agents = get_agents(args)

with jax.disable_jit():
env = make('hanabi', debug=True)
env = make('hanabi', debug=False)
rng = jax.random.PRNGKey(seed)
rng, _rng = jax.random.split(rng)
obs, state = env.reset(_rng)
legal_moves = env.get_legal_moves(state)

env.render(state)

done = False

print("\n" + "=" * 40 + "\n")

while not done:
env.render(state)
print()

# take action input from user
while True:
try:
action = int(input('Insert next action: '))
if action>=0&action<=20:
break
else:
print('Invalid action.')
except KeyboardInterrupt:
sys.exit(0)
except:
print('Invalid action.')
curr_player = np.where(state.cur_player_idx==1)[0][0]
actions_all = [
agents[i].act(env, obs, legal_moves, curr_player)
for i in range(len(env.agents))
]

actions = actions_all[curr_player]

played_action = (actions[curr_player] - 1) % 20
print("played action:", played_action)
print(f"Move played: {action_encoding[played_action]} ({played_action})")

actions = {agent:jnp.array([action]) for agent in env.agents}
actions = {agent:jnp.array([actions[i]]) for i, agent in enumerate(env.agents)}
rng, _rng = jax.random.split(rng)
obs, state, reward, dones, infos = env.step(_rng, state, actions)
legal_moves = env.get_legal_moves(state)
done = dones['__all__']
env.render(state, debug=False)

print("\n" + "=" * 40 + "\n")



print('Game Ended.')

def main():

def main(args):
action_encoding = {
"Discard 0": 1,
"Discard 1": 2,
"Discard 2": 3,
"Discard 3": 4,
"Discard 4": 5,
"Play 0": 6,
"Play 1": 7,
"Play 2": 8,
"Play 3": 9,
"Play 4": 10,
"Reveal player +1 color R": 11,
"Reveal player +1 color Y": 12,
"Reveal player +1 color G": 13,
"Reveal player +1 color W": 14,
"Reveal player +1 color B": 15,
"Reveal player +1 rank 1": 16,
"Reveal player +1 rank 2": 17,
"Reveal player +1 rank 3": 18,
"Reveal player +1 rank 4": 19,
"Reveal player +1 rank 5": 20,
"INVALID": 0
0: "Discard 0",
1: "Discard 1",
2: "Discard 2",
3: "Discard 3",
4: "Discard 4",
5: "Play 0",
6: "Play 1",
7: "Play 2",
8: "Play 3",
9: "Play 4",
10: "Reveal player +1 color R",
11: "Reveal player +1 color Y",
12: "Reveal player +1 color G",
13: "Reveal player +1 color W",
14: "Reveal player +1 color B",
15: "Reveal player +1 rank 1",
16: "Reveal player +1 rank 2",
17: "Reveal player +1 rank 3",
18: "Reveal player +1 rank 4",
19: "Reveal player +1 rank 5",
20: "INVALID",
}


print('Starting Hanabi. Remember, actions encoding is:')
pprint.pprint(action_encoding)

play_game()
play_game(args, action_encoding)

new_game = 'y'
while new_game=='y':
new_game = input('New Game?')
if new_game=='y':
play_game()
play_game(args, action_encoding)


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--player0", type=str, default="obl")
parser.add_argument("--player1", type=str, default="manual")
parser.add_argument("--weight", type=str, default=None)
parser.add_argument("--weight0", type=str, default=None)
parser.add_argument("--weight1", type=str, default=None)
parser.add_argument("--seed", type=int, default=None)
args = parser.parse_args()
print(args)
return args

def batchify(env, x):
return jnp.stack([x[a] for a in env.agents])


if __name__=='__main__':
main()
args = parse_args()
main(args)
Loading

0 comments on commit aae9f99

Please sign in to comment.