diff --git a/examples/trees/mcts.py b/examples/trees/mcts.py
new file mode 100644
index 00000000000..9b367331f6f
--- /dev/null
+++ b/examples/trees/mcts.py
@@ -0,0 +1,161 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import time
+
+import torch
+import torchrl
+import torchrl.envs
+import torchrl.modules.mcts
+from tensordict import TensorDict
+from torchrl.data import Composite, Unbounded
+from torchrl.envs import Transform
+
+pgn_or_fen = "fen"
+mask_actions = True
+
+env = torchrl.envs.ChessEnv(
+    include_pgn=False,
+    include_fen=True,
+    include_hash=True,
+    include_hash_inv=True,
+    include_san=True,
+    stateful=True,
+    mask_actions=mask_actions,
+)
+
+
+class TurnBasedChess(Transform):
+    def transform_observation_spec(self, obsspec):
+        obsspec["agent0", "turn"] = Unbounded(dtype=torch.bool, shape=())
+        obsspec["agent1", "turn"] = Unbounded(dtype=torch.bool, shape=())
+        return obsspec
+
+    def transform_reward_spec(self, reward_spec):
+        reward = reward_spec["reward"].clone()
+        del reward_spec["reward"]
+        return Composite(
+            agent0=Composite(reward=reward),
+            agent1=Composite(reward=reward),
+        )
+
+    def _reset(self, _td, td):
+        td["agent0", "turn"] = td["turn"]
+        td["agent1", "turn"] = ~td["turn"]
+        return td
+
+    def _step(self, td, td_next):
+        td_next["agent0", "turn"] = td_next["turn"]
+        td_next["agent1", "turn"] = ~td_next["turn"]
+
+        reward = td_next["reward"]
+        turn = td["turn"]
+
+        if reward == 0.5:
+            reward = 0
+        elif reward == 1:
+            if not turn:
+                reward = -reward
+
+        td_next["agent0", "reward"] = reward
+        td_next["agent1", "reward"] = -reward
+        del td_next["reward"]
+
+        return td_next
+
+
+env = env.append_transform(TurnBasedChess())
+env.rollout(3)
+
+forest = torchrl.data.MCTSForest()
+forest.reward_keys = env.reward_keys
+forest.done_keys = env.done_keys
+forest.action_keys = env.action_keys
+
+if mask_actions:
+    forest.observation_keys = [
+        f"{pgn_or_fen}_hash",
+        "turn",
+        "action_mask",
+        ("agent0", "turn"),
+        ("agent1", "turn"),
+    ]
+else:
+    forest.observation_keys = [
+        f"{pgn_or_fen}_hash",
+        "turn",
+        ("agent0", "turn"),
+        ("agent1", "turn"),
+    ]
+
+
+def tree_format_fn(tree):
+    td = tree.rollout[-1]["next"]
+    return [
+        td["san"],
+        td[pgn_or_fen].split("\n")[-1],
+        tree.wins,
+        tree.visits,
+    ]
+
+
+def get_best_move(fen, mcts_steps, rollout_steps):
+    root = env.reset(TensorDict({"fen": fen}))
+    agent_keys = ["agent0", "agent1"]
+    mcts = torchrl.modules.mcts.MCTS(mcts_steps, rollout_steps, agent_keys=agent_keys)
+    tree = mcts(forest, root, env)
+    moves = []
+
+    for subtree in tree.subtree:
+        td = subtree.rollout[0]
+        san = td["next", "san"]
+        active_agent = agent_keys[
+            torch.stack([td[agent]["turn"] for agent in agent_keys]).nonzero()
+        ]
+        reward_sum = subtree.wins[active_agent, "reward"]
+        visits = subtree.visits
+        value_avg = (reward_sum / visits).item()
+        moves.append((value_avg, san))
+
+    moves = sorted(moves, key=lambda x: -x[0])
+
+    # print(tree.to_string(tree_format_fn))
+
+    print("------------------")
+    for value_avg, san in moves:
+        print(f" {value_avg:0.02f} {san}")
+    print("------------------")
+
+    return moves[0][1]
+
+
+for idx in range(3):
+    print("==========")
+    print(idx)
+    print("==========")
+    torch.manual_seed(idx)
+
+    start_time = time.time()
+
+    # White has M1, best move Rd8#. Any other moves lose to M2 or M1.
+    fen0 = "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1"
+    assert get_best_move(fen0, 40, 10) == "Rd8#"
+
+    # Black has M1, best move Qg6#. Other moves give rough equality or worse.
+    fen1 = "6qk/2R4p/7K/8/8/8/8/4R3 b - - 1 1"
+    assert get_best_move(fen1, 40, 10) == "Qg6#"
+
+    # White has M2, best move Rxg8+. Any other move loses.
+    fen2 = "2R3qk/5p1p/7K/8/8/8/5r2/2R5 w - - 0 1"
+    assert get_best_move(fen2, 600, 10) == "Rxg8+"
+
+    # Black has M2, best move Rxg1+. Any other move loses.
+    fen3 = "2r5/5R2/8/8/8/7k/5P1P/2r3QK b - - 0 1"
+    assert get_best_move(fen3, 600, 10) == "Rxg1+"
+
+    end_time = time.time()
+    total_time = end_time - start_time
+
+    print(f"Took {total_time} s")
diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py
index a45ab58662b..7f2957aefd6 100644
--- a/torchrl/data/map/tree.py
+++ b/torchrl/data/map/tree.py
@@ -1363,6 +1363,11 @@ def valid_paths(cls, tree: Tree):
     def __len__(self):
         return len(self.data_map)
 
+    def __contains__(self, root: TensorDictBase):
+        if self.node_map is None:
+            return False
+        return root.select(*self.node_map.in_keys) in self.node_map
+
     def to_string(self, td_root, node_format_fn=lambda tree: tree.node_data.to_dict()):
         """Generates a string representation of a tree in the forest.
 
diff --git a/torchrl/envs/custom/chess.py b/torchrl/envs/custom/chess.py
index 23ccbc73b0a..194af842724 100644
--- a/torchrl/envs/custom/chess.py
+++ b/torchrl/envs/custom/chess.py
@@ -220,12 +220,15 @@ def lib(cls):
         return chess
 
     _san_moves = []
+    _san_move_to_index_map = {}
 
     @_classproperty
     def san_moves(cls):
         if not cls._san_moves:
             with open(pathlib.Path(__file__).parent / "san_moves.txt", "r+") as f:
                 cls._san_moves.extend(f.read().split("\n"))
+            for idx, san_move in enumerate(cls._san_moves):
+                cls._san_move_to_index_map[san_move] = idx
         return cls._san_moves
 
     def _legal_moves_to_index(
@@ -253,7 +256,7 @@ def _legal_moves_to_index(
         board = self.board
 
         indices = torch.tensor(
-            [self._san_moves.index(board.san(m)) for m in board.legal_moves],
+            [self._san_move_to_index_map[board.san(m)] for m in board.legal_moves],
             dtype=torch.int64,
         )
         mask = None
@@ -407,7 +410,9 @@ def _reset(self, tensordict=None):
             if move is None:
                 dest.set("san", "")
             else:
-                dest.set("san", self.board.san(move))
+                prev_board = self.board.copy()
+                prev_board.pop()
+                dest.set("san", prev_board.san(move))
         if self.include_fen:
             dest.set("fen", fen)
         if self.include_pgn:
diff --git a/torchrl/modules/mcts/__init__.py b/torchrl/modules/mcts/__init__.py
new file mode 100644
index 00000000000..2e90d338f4a
--- /dev/null
+++ b/torchrl/modules/mcts/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .mcts import MCTS
diff --git a/torchrl/modules/mcts/mcts.py b/torchrl/modules/mcts/mcts.py
new file mode 100644
index 00000000000..0554349d384
--- /dev/null
+++ b/torchrl/modules/mcts/mcts.py
@@ -0,0 +1,143 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Sequence
+
+import torch
+import torchrl
+from tensordict import TensorDict, TensorDictBase
+from tensordict.utils import NestedKey
+from torch import nn
+
+from torchrl.data.map import MCTSForest, Tree
+from torchrl.envs import EnvBase
+
+
+class MCTS(nn.Module):
+    """Monte-Carlo tree search.
+
+    Attributes:
+        num_traversals (int): Number of times to traverse the tree.
+        rollout_max_steps (int): Maximum number of steps for each rollout.
+
+    Methods:
+        forward: Runs the tree search.
+    """
+
+    def __init__(
+        self,
+        num_traversals: int,
+        rollout_max_steps: int | None = None,
+        agent_keys: Sequence[NestedKey] | None = None,
+        turn_key: NestedKey = ("turn",),
+    ):
+        super().__init__()
+        self.num_traversals = num_traversals
+        self.rollout_max_steps = rollout_max_steps
+        self.agent_keys = agent_keys
+        self.turn_key = turn_key
+
+    def forward(
+        self,
+        forest: MCTSForest,
+        root: TensorDictBase,
+        env: EnvBase,
+    ) -> Tree:
+        """Performs Monte-Carlo tree search in an environment.
+
+        Args:
+            forest (MCTSForest): Forest of the tree to update. If the tree does not
+                exist yet, it is added.
+            root (TensorDict): The root step of the tree to update.
+            env (EnvBase): Environment to perform actions in.
+        """
+        for action in env.all_actions(root):
+            td = env.step(env.reset(root.clone()).update(action))
+            forest.extend(td.unsqueeze(0))
+
+        tree = forest.get_tree(root)
+
+        tree.wins = env.reward_spec.zero()
+
+        for subtree in tree.subtree:
+            subtree.wins = env.reward_spec.zero()
+
+        for _ in range(self.num_traversals):
+            self._traverse_MCTS_one_step(forest, tree, env, self.rollout_max_steps)
+
+        return tree
+
+    def _traverse_MCTS_one_step(self, forest, tree, env, rollout_max_steps):
+        done = False
+        trees_visited = [tree]
+
+        while not done:
+            if tree.subtree is None:
+                td_tree = tree.rollout[-1]["next"].clone()
+
+                if (tree.visits > 0 or tree.parent is None) and not td_tree["done"]:
+                    actions = env.all_actions(td_tree)
+                    subtrees = []
+
+                    for action in actions:
+                        td = env.step(env.reset(td_tree).update(action))
+                        new_node = torchrl.data.Tree(
+                            rollout=td.unsqueeze(0),
+                            node_data=td["next"].select(*forest.node_map.in_keys),
+                            count=torch.tensor(0),
+                            wins=env.reward_spec.zero(),
+                        )
+                        subtrees.append(new_node)
+
+                    # NOTE: This whole script runs about 2x faster with lazy stack
+                    # versus eager stack.
+                    tree.subtree = TensorDict.lazy_stack(subtrees)
+                    chosen_idx = torch.randint(0, len(subtrees), ()).item()
+                    rollout_state = subtrees[chosen_idx].rollout[-1]["next"]
+
+                else:
+                    rollout_state = td_tree
+
+                if rollout_state["done"]:
+                    rollout_reward = rollout_state.select(*env.reward_keys)
+                else:
+                    rollout = env.rollout(
+                        max_steps=rollout_max_steps,
+                        tensordict=rollout_state,
+                    )
+                    rollout_reward = rollout[-1]["next"].select(*env.reward_keys)
+                done = True
+
+            else:
+                priorities = self._traversal_priority_UCB1(tree)
+                chosen_idx = torch.argmax(priorities).item()
+                tree = tree.subtree[chosen_idx]
+                trees_visited.append(tree)
+
+        for tree in trees_visited:
+            tree.visits += 1
+            tree.wins += rollout_reward
+
+    def _get_active_agent(self, td: TensorDict) -> str:
+        turns = torch.stack([td[agent][self.turn_key] for agent in self.agent_keys])
+        if turns.sum() != 1:
+            raise ValueError(
+                "MCTS only supports environments in which it is only one agent's turn at a time."
+            )
+        return self.agent_keys[turns.nonzero()]
+
+    # TODO: Allow user to specify different priority functions with PR #2358
+    def _traversal_priority_UCB1(self, tree):
+        subtree = tree.subtree
+        visits = subtree.visits
+        reward_sum = subtree.wins
+        parent_visits = tree.visits
+        active_agent = self._get_active_agent(subtree.rollout[0, 0])
+        reward_sum = reward_sum[active_agent, "reward"].squeeze(-1)
+
+        C = 2.0**0.5
+        priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits
+        priority[visits == 0] = float("inf")
+        return priority
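
The snippets below are illustrative additions, not part of the diff. First, a minimal sketch of how the new `MCTSForest.__contains__` from `torchrl/data/map/tree.py` might be exercised. The `ChessEnv` keyword arguments, the single-key observation spec, and the 2-step rollout are assumptions chosen for brevity: before anything is stored the node map is `None`, so membership is `False`; after an `extend()` the start position should be found again.

import torchrl
import torchrl.envs

# Hypothetical setup: a bare chess env and a forest keyed only on the FEN hash.
env = torchrl.envs.ChessEnv(
    include_fen=True,
    include_hash=True,
    include_hash_inv=True,
    stateful=True,
)
forest = torchrl.data.MCTSForest()
forest.reward_keys = env.reward_keys
forest.done_keys = env.done_keys
forest.action_keys = env.action_keys
forest.observation_keys = ["fen_hash"]

root = env.reset()
print(root in forest)  # False: node_map is None, nothing stored yet

forest.extend(env.rollout(2))  # a short random rollout from the start position
print(root in forest)  # True: the start position's node keys are now in node_map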
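
Similarly, the `chess.py` change builds a dictionary alongside `_san_moves` so that legal-move lookups become a dict access instead of a repeated `list.index` scan. A tiny standalone sketch of the idea, with a made-up SAN vocabulary standing in for `san_moves.txt`:

# Placeholder vocabulary; the real list is read from san_moves.txt.
san_moves = ["e4", "e5", "Nf3", "Nc6"]
san_move_to_index_map = {move: idx for idx, move in enumerate(san_moves)}

# The dict lookup returns the same index as the linear scan it replaces.
assert san_move_to_index_map["Nf3"] == san_moves.index("Nf3") == 2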
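
Finally, a self-contained sketch of the traversal priority computed by `_traversal_priority_UCB1` above. The win/visit statistics are made-up placeholders for three child nodes, not output from the example script; unvisited children are forced to infinite priority so they are expanded first.

import torch

reward_sum = torch.tensor([1.0, 0.0, 0.0])  # hypothetical per-child reward sums
visits = torch.tensor([2.0, 1.0, 0.0])      # hypothetical per-child visit counts
parent_visits = torch.tensor(3.0)           # visits of the node being traversed

C = 2.0**0.5  # exploration constant used in the module above

# Same expression as in _traversal_priority_UCB1.
priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits
priority[visits == 0] = float("inf")
print(priority)  # approximately tensor([1.2412, 1.4823, inf])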