This repository has been archived by the owner on Jul 16, 2024. It is now read-only.

Commit 3c6bb8b
removing agent strategy from the state class
big-c-note committed May 18, 2020
1 parent d6b666b commit 3c6bb8b
Showing 3 changed files with 29 additions and 36 deletions.
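
In outline, the commit moves the offline strategy off the state object: `ShortDeckPokerState.__init__` no longer accepts an `offline_strategy` argument, `get_game_state` becomes `load_game_state` and receives the strategy explicitly, and the now-private `_update_hole_cards_bayes` reads the strategy from its parameter rather than from `self._offline_strategy`. A minimal sketch of the interface change follows; the signatures mirror the diff below, the `Before`/`After` class names are illustrative, and all bodies and unrelated parameters are elided.

```python
from typing import Dict


class ShortDeckPokerStateBefore:
    """Old interface: the offline strategy is stored on the state."""

    def __init__(self, real_time_test: bool = False, offline_strategy: Dict = None):
        self._offline_strategy = offline_strategy

    def get_game_state(self, action_sequence: list):
        ...  # replays the action sequence only

    def update_hole_cards_bayes(self):
        ...  # reads self._offline_strategy


class ShortDeckPokerStateAfter:
    """New interface: the strategy is passed in only where it is needed."""

    def __init__(self, real_time_test: bool = False):
        ...

    def load_game_state(self, offline_strategy: Dict, action_sequence: list):
        ...  # forwards offline_strategy to _update_hole_cards_bayes

    def _update_hole_cards_bayes(self, offline_strategy: Dict):
        ...
```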
22 changes: 12 additions & 10 deletions pluribus/games/short_deck/state.py
@@ -60,7 +60,6 @@ def __init__(
pickle_dir: str = ".",
load_pickle_files: bool = True,
real_time_test: bool = False,
offline_strategy: Dict = None,
public_cards: List[Card] = []
):
"""Initialise state."""
@@ -133,8 +132,6 @@ def __init__(
# only want to do these actions in real game play, as they are slow
if self.real_time_test:
# must have offline strategy loaded up
assert offline_strategy
self._offline_strategy = offline_strategy
self._starting_hand_probs = self._initialize_starting_hands()
# TODO: We might not need this
cards_in_deck = self._table.dealer.deck._cards_in_deck
@@ -337,7 +334,7 @@ def _normalize_bayes(self):
for starting_hand, prob in self._starting_hand_probs[player].items():
self._starting_hand_probs[player][starting_hand] = prob / total_prob

def update_hole_cards_bayes(self):
def _update_hole_cards_bayes(self, offline_strategy: Dict):
"""Get probability of reach for each pair of hole cards for each player"""
n_players = len(self._table.players)
player_indices: List[int] = [p_i for p_i in range(n_players)]
@@ -402,7 +399,7 @@ def update_hole_cards_bayes(self):
# doesn't work for calculations that need to be made with the object's values

try: # TODO: with or without keys
prob = self._offline_strategy[infoset][action]
prob = offline_strategy[infoset][action]
except KeyError:
prob = 1 / len(self.legal_actions)
prob_reach_all_hands.append(prob)
@@ -433,7 +430,7 @@ def update_hole_cards_bayes(self):
)
# TODO: Check this
try:
prob = self._offline_strategy[infoset][action]
prob = offline_strategy[infoset][action]
except KeyError:
prob = 1 / len(self.legal_actions)
if "p_reach" not in locals():
@@ -444,7 +441,6 @@ def update_hole_cards_bayes(self):
self._starting_hand_probs[p_i][tuple(starting_hand)] = p_reach
self._normalize_bayes()
# TODO: delete this? at least for our purposes we don't need it again
del self._offline_strategy

def deal_bayes(self):
start = time.time()
@@ -481,18 +477,24 @@ def update_hole_cards_bayes(self):
return new_state
# TODO add convenience method to supply public cards

def get_game_state(self, action_sequence: list):
def load_game_state(self, offline_strategy: Dict, action_sequence: list):
"""
Follow through the action sequence provided to get current node.
:param action_sequence: List of actions without 'skip'
"""
if not action_sequence:
return self
# TODO: not 100 percent sure I need to deep copy
lut = self.info_set_lut
self.info_set_lut = {}
new_state = copy.deepcopy(self)
new_state.info_set_lut = self.info_set_lut = lut
new_state._update_hole_cards_bayes(offline_strategy)
return new_state
a = action_sequence.pop(0)
if a == "skip":
a = action_sequence.pop(0)
new_state = self.apply_action(a)
return new_state.get_game_state(action_sequence)
return new_state.load_game_state(offline_strategy, action_sequence)

def _get_starting_hand(self, player_idx: int):
"""Get starting hand based on probability of reach"""
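From the caller's perspective, the new flow is roughly as follows. This is a minimal sketch: it assumes `new_game` is exported by `pluribus.games.short_deck.state` (as the wildcard import and the call in `RT_cfr.py` below suggest), and the pickle filename and action sequence are purely illustrative.

```python
import pickle

from pluribus.games.short_deck.state import ShortDeckPokerState, new_game

# Illustrative filename -- point this at wherever the offline strategy was dumped.
with open("offline-strategy.pkl", "rb") as f:
    offline_strategy = pickle.load(f)

action_sequence = ["call", "call", "call"]  # example sequence, without 'skip' entries
public_cards = []  # supply board cards when resuming a hand past the preflop round

state: ShortDeckPokerState = new_game(
    3, real_time_test=True, public_cards=public_cards
)
# Replays the action sequence, then performs the one-time Bayesian hole-card update.
current_game_state = state.load_game_state(offline_strategy, action_sequence)
```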
2 changes: 1 addition & 1 deletion research/test_methodology/RT.py
@@ -21,7 +21,7 @@
agent_output = train(
agent1.offline_strategy, public_cards, action_sequence, 40, 6, 6, 3, 2, 6
) # TODO: back to 50
with open("realtime-strategy-moved-agent.pkl", "wb") as file:
with open("realtime-strategy-refactor-game-state.pkl", "wb") as file:
pickle.dump(agent_output, file)
import ipdb

41 changes: 16 additions & 25 deletions research/test_methodology/RT_cfr.py
@@ -8,9 +8,11 @@

from tqdm import trange

from pluribus import utils
from pluribus.games.short_deck.state import *
from pluribus.games.short_deck.agent import *


def update_strategy(agent: Agent, state: ShortDeckPokerState, ph_test_node: int):
"""
@@ -37,7 +39,8 @@ def update_strategy(agent: Agent, state: ShortDeckPokerState, ph_test_node: int)
try:
I = state.info_set
except:
import ipdb;
import ipdb

ipdb.set_trace()
# calculate regret
logging.debug(f"About to Calculate Strategy, Regret: {agent.regret[I]}")
Expand All @@ -63,7 +66,7 @@ def update_strategy(agent: Agent, state: ShortDeckPokerState, ph_test_node: int)


def calculate_strategy(
regret: Dict[str, Dict[str, float]], I: str, state: ShortDeckPokerState,
regret: Dict[str, Dict[str, float]], I: str, state: ShortDeckPokerState,
):
"""
@@ -130,7 +133,8 @@ def cfr(agent: Agent, state: ShortDeckPokerState, i: int, t: int) -> float:
try:
I = state.info_set
except:
import ipdb;
import ipdb

ipdb.set_trace()
# calculate strategy
logging.debug(f"About to Calculate Strategy, Regret: {agent.regret[I]}")
@@ -165,7 +169,8 @@ def cfr(agent: Agent, state: ShortDeckPokerState, i: int, t: int) -> float:
try:
Iph = state.info_set
except:
import ipdb;
import ipdb

ipdb.set_trace()
logging.debug(f"About to Calculate Strategy, Regret: {agent.regret[Iph]}")
logging.debug(f"Current regret: {agent.regret[Iph]}")
@@ -188,29 +193,10 @@ def cfr(agent: Agent, state: ShortDeckPokerState, i: int, t: int) -> float:
new_state: ShortDeckPokerState = state.apply_action(a)
return cfr(agent, new_state, i, t)

# added some flags for RT
def new_rt_game(
n_players: int, offline_strategy: Dict, action_sequence, public_cards=[], real_time_test=True
) -> ShortDeckPokerState:
"""Create a new game of short deck poker."""
pot = Pot()
players = [
ShortDeckPokerPlayer(player_i=player_i, initial_chips=10000, pot=pot)
for player_i in range(n_players)
]
state = ShortDeckPokerState(
players=players, offline_strategy=offline_strategy, real_time_test=real_time_test, public_cards=public_cards
)
current_game_state = state.get_game_state(action_sequence)
# decided to make this a one time method rather than something that always updates
# reason being: we won't need it except for a few choice nodes
current_game_state.update_hole_cards_bayes()
return current_game_state


def train(
offline_strategy: Dict,
public_cards,
public_cards: list,
action_sequence: list,
n_iterations: int,
lcfr_threshold: int,
@@ -223,7 +209,12 @@ def train(
utils.random.seed(36)
agent = Agent()

current_game_state = new_rt_game(3, offline_strategy, action_sequence, public_cards)
state: ShortDeckPokerState = new_game(3, real_time_test=True, public_cards=public_cards)
current_game_state: ShortDeckPokerState = state.load_game_state(
offline_strategy,
action_sequence
)
del offline_strategy
ph_test_node = current_game_state.player_i
for t in trange(1, n_iterations + 1, desc="train iter"):
print(t)
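The diff touches the signature of `calculate_strategy(regret, I, state)` without showing its body. For context, CFR-style solvers usually derive the current strategy at an info set from positive regret via regret matching; a generic sketch (the `_sketch` name, signature, and uniform fallback are assumptions, not this repository's exact implementation) looks like:

```python
from typing import Dict, List


def calculate_strategy_sketch(
    regret: Dict[str, Dict[str, float]], I: str, legal_actions: List[str]
) -> Dict[str, float]:
    """Return a strategy at info set I proportional to positive regret."""
    positive_regret = {a: max(regret[I].get(a, 0.0), 0.0) for a in legal_actions}
    total = sum(positive_regret.values())
    if total > 0:
        return {a: r / total for a, r in positive_regret.items()}
    # No action has positive regret: fall back to a uniform strategy.
    return {a: 1.0 / len(legal_actions) for a in legal_actions}
```

In the diff above, `state.legal_actions` would supply the action list, mirroring the uniform `1 / len(self.legal_actions)` fallback that `_update_hole_cards_bayes` uses when an info set is missing from the offline strategy.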
