[Feature,Example] Add MCTS algorithm and example #2796

Open · wants to merge 13 commits into base: gh/kurtamohler/5/base
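
This PR adds a generic MCTS module under torchrl/modules/mcts and a chess example under examples/trees, along with supporting changes to MCTSForest and ChessEnv. A minimal usage sketch, mirroring the example below (the environment setup is elided; the names come from this PR's code):

import torchrl
from torchrl.modules.mcts import MCTS

env = ...  # any EnvBase whose actions can be enumerated via env.all_actions()
forest = torchrl.data.MCTSForest()
# forest.reward_keys / done_keys / action_keys / observation_keys must be
# configured to match the env, as shown in the example below.

root = env.reset()
mcts = MCTS(num_traversals=40, rollout_max_steps=10, agent_keys=["agent0", "agent1"])
tree = mcts(forest, root, env)  # returns the updated torchrl.data.Tree
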
161 changes: 161 additions & 0 deletions examples/trees/mcts.py
@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time

import torch
import torchrl
import torchrl.envs
import torchrl.modules.mcts
from tensordict import TensorDict
from torchrl.data import Composite, Unbounded
from torchrl.envs import Transform

pgn_or_fen = "fen"
mask_actions = True

env = torchrl.envs.ChessEnv(
include_pgn=False,
include_fen=True,
include_hash=True,
include_hash_inv=True,
include_san=True,
stateful=True,
mask_actions=mask_actions,
)


class TurnBasedChess(Transform):
def transform_observation_spec(self, obsspec):
obsspec["agent0", "turn"] = Unbounded(dtype=torch.bool, shape=())
obsspec["agent1", "turn"] = Unbounded(dtype=torch.bool, shape=())
return obsspec

def transform_reward_spec(self, reward_spec):
reward = reward_spec["reward"].clone()
del reward_spec["reward"]
return Composite(
agent0=Composite(reward=reward),
agent1=Composite(reward=reward),
)

def _reset(self, _td, td):
td["agent0", "turn"] = td["turn"]
td["agent1", "turn"] = ~td["turn"]
return td

def _step(self, td, td_next):
td_next["agent0", "turn"] = td_next["turn"]
td_next["agent1", "turn"] = ~td_next["turn"]

reward = td_next["reward"]
turn = td["turn"]

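        # Map the scalar env reward onto the two agents: a draw (0.5) is worth
        # zero to both, and a win (1) is credited to the side that just moved,
        # so it is negated for agent0 (white) whenever black made the move.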
if reward == 0.5:
reward = 0
elif reward == 1:
if not turn:
reward = -reward

td_next["agent0", "reward"] = reward
td_next["agent1", "reward"] = -reward
del td_next["reward"]

return td_next


env = env.append_transform(TurnBasedChess())
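# Quick sanity check: run a short random rollout through the transformed env.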
env.rollout(3)

forest = torchrl.data.MCTSForest()
forest.reward_keys = env.reward_keys
forest.done_keys = env.done_keys
forest.action_keys = env.action_keys

if mask_actions:
forest.observation_keys = [
f"{pgn_or_fen}_hash",
"turn",
"action_mask",
("agent0", "turn"),
("agent1", "turn"),
]
else:
forest.observation_keys = [
f"{pgn_or_fen}_hash",
"turn",
("agent0", "turn"),
("agent1", "turn"),
]


def tree_format_fn(tree):
td = tree.rollout[-1]["next"]
return [
td["san"],
td[pgn_or_fen].split("\n")[-1],
tree.wins,
tree.visits,
]


def get_best_move(fen, mcts_steps, rollout_steps):
root = env.reset(TensorDict({"fen": fen}))
agent_keys = ["agent0", "agent1"]
mcts = torchrl.modules.mcts.MCTS(mcts_steps, rollout_steps, agent_keys=agent_keys)
tree = mcts(forest, root, env)
moves = []

for subtree in tree.subtree:
td = subtree.rollout[0]
san = td["next", "san"]
active_agent = agent_keys[
torch.stack([td[agent]["turn"] for agent in agent_keys]).nonzero()
]
reward_sum = subtree.wins[active_agent, "reward"]
visits = subtree.visits
value_avg = (reward_sum / visits).item()
moves.append((value_avg, san))

moves = sorted(moves, key=lambda x: -x[0])

# print(tree.to_string(tree_format_fn))

print("------------------")
for value_avg, san in moves:
print(f" {value_avg:0.02f} {san}")
print("------------------")

return moves[0][1]


for idx in range(3):
print("==========")
print(idx)
print("==========")
torch.manual_seed(idx)

start_time = time.time()

# White has M1, best move Rd8#. Any other moves lose to M2 or M1.
fen0 = "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1"
assert get_best_move(fen0, 40, 10) == "Rd8#"

# Black has M1, best move Qg6#. Other moves give rough equality or worse.
fen1 = "6qk/2R4p/7K/8/8/8/8/4R3 b - - 1 1"
assert get_best_move(fen1, 40, 10) == "Qg6#"

# White has M2, best move Rxg8+. Any other move loses.
fen2 = "2R3qk/5p1p/7K/8/8/8/5r2/2R5 w - - 0 1"
assert get_best_move(fen2, 600, 10) == "Rxg8+"

# Black has M2, best move Rxg1+. Any other move loses.
fen3 = "2r5/5R2/8/8/8/7k/5P1P/2r3QK b - - 0 1"
assert get_best_move(fen3, 600, 10) == "Rxg1+"

end_time = time.time()
total_time = end_time - start_time

print(f"Took {total_time} s")
5 changes: 5 additions & 0 deletions torchrl/data/map/tree.py
@@ -1363,6 +1363,11 @@ def valid_paths(cls, tree: Tree):
def __len__(self):
return len(self.data_map)

+    def __contains__(self, root: TensorDictBase):
+        if self.node_map is None:
+            return False
+        return root.select(*self.node_map.in_keys) in self.node_map
+
def to_string(self, td_root, node_format_fn=lambda tree: tree.node_data.to_dict()):
"""Generates a string representation of a tree in the forest.

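
The new __contains__ makes it cheap to test whether a root state already has a node in the forest before expanding it, e.g. (hypothetical usage, reusing the expansion pattern from MCTS.forward below):

root = env.reset()
if root not in forest:
    for action in env.all_actions(root):
        td = env.step(env.reset(root.clone()).update(action))
        forest.extend(td.unsqueeze(0))
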
9 changes: 7 additions & 2 deletions torchrl/envs/custom/chess.py
@@ -220,12 +220,15 @@ def lib(cls):
return chess

_san_moves = []
+    _san_move_to_index_map = {}

@_classproperty
def san_moves(cls):
if not cls._san_moves:
with open(pathlib.Path(__file__).parent / "san_moves.txt", "r+") as f:
cls._san_moves.extend(f.read().split("\n"))
+            for idx, san_move in enumerate(cls._san_moves):
+                cls._san_move_to_index_map[san_move] = idx
return cls._san_moves

def _legal_moves_to_index(
@@ -253,7 +256,7 @@ def _legal_moves_to_index(
board = self.board

indices = torch.tensor(
-            [self._san_moves.index(board.san(m)) for m in board.legal_moves],
+            [self._san_move_to_index_map[board.san(m)] for m in board.legal_moves],
dtype=torch.int64,
)
mask = None
@@ -407,7 +410,9 @@ def _reset(self, tensordict=None):
if move is None:
dest.set("san", "<start>")
else:
dest.set("san", self.board.san(move))
prev_board = self.board.copy()
prev_board.pop()
dest.set("san", prev_board.san(move))
if self.include_fen:
dest.set("fen", fen)
if self.include_pgn:
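
Two separate fixes here: legal-move indexing now uses a precomputed dict, turning each lookup from an O(n) list.index scan over the SAN move list into an O(1) hit, and _reset recomputes the SAN string on the position before the move was played, which is what python-chess's san() expects. A standalone illustration of the latter (a sketch against the python-chess API):

import chess

board = chess.Board()
board.push_san("e4")
prev_board = board.copy()
move = prev_board.pop()  # undoes the last move and returns it
assert prev_board.san(move) == "e4"  # san() is only valid on the pre-move board
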
6 changes: 6 additions & 0 deletions torchrl/modules/mcts/__init__.py
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .mcts import MCTS
143 changes: 143 additions & 0 deletions torchrl/modules/mcts/mcts.py
@@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Sequence

import torch
import torchrl
from tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey
from torch import nn

from torchrl.data.map import MCTSForest, Tree
from torchrl.envs import EnvBase


class MCTS(nn.Module):
"""Monte-Carlo tree search.

Attributes:
num_traversals (int): Number of times to traverse the tree.
rollout_max_steps (int): Maximum number of steps for each rollout.

Methods:
forward: Runs the tree search.
"""

def __init__(
self,
num_traversals: int,
rollout_max_steps: int | None = None,
        agent_keys: Sequence[NestedKey] | None = None,
turn_key: NestedKey = ("turn",),
):
super().__init__()
self.num_traversals = num_traversals
self.rollout_max_steps = rollout_max_steps
self.agent_keys = agent_keys
self.turn_key = turn_key

def forward(
self,
forest: MCTSForest,
root: TensorDictBase,
env: EnvBase,
) -> Tree:
"""Performs Monte-Carlo tree search in an environment.

Args:
forest (MCTSForest): Forest of the tree to update. If the tree does not
exist yet, it is added.
root (TensorDict): The root step of the tree to update.
            env (EnvBase): Environment to perform actions in.
"""
for action in env.all_actions(root):
td = env.step(env.reset(root.clone()).update(action))
forest.extend(td.unsqueeze(0))

tree = forest.get_tree(root)

tree.wins = env.reward_spec.zero()

for subtree in tree.subtree:
subtree.wins = env.reward_spec.zero()

for _ in range(self.num_traversals):
self._traverse_MCTS_one_step(forest, tree, env, self.rollout_max_steps)

return tree

def _traverse_MCTS_one_step(self, forest, tree, env, rollout_max_steps):
done = False
trees_visited = [tree]

while not done:
if tree.subtree is None:
td_tree = tree.rollout[-1]["next"].clone()

if (tree.visits > 0 or tree.parent is None) and not td_tree["done"]:
actions = env.all_actions(td_tree)
subtrees = []

for action in actions:
td = env.step(env.reset(td_tree).update(action))
new_node = torchrl.data.Tree(
rollout=td.unsqueeze(0),
node_data=td["next"].select(*forest.node_map.in_keys),
count=torch.tensor(0),
wins=env.reward_spec.zero(),
)
subtrees.append(new_node)

# NOTE: This whole script runs about 2x faster with lazy stack
# versus eager stack.
tree.subtree = TensorDict.lazy_stack(subtrees)
chosen_idx = torch.randint(0, len(subtrees), ()).item()
rollout_state = subtrees[chosen_idx].rollout[-1]["next"]

else:
rollout_state = td_tree

if rollout_state["done"]:
rollout_reward = rollout_state.select(*env.reward_keys)
else:
rollout = env.rollout(
max_steps=rollout_max_steps,
tensordict=rollout_state,
)
rollout_reward = rollout[-1]["next"].select(*env.reward_keys)
done = True

else:
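                # Selection: descend into the child with the highest UCB1 priority.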
priorities = self._traversal_priority_UCB1(tree)
chosen_idx = torch.argmax(priorities).item()
tree = tree.subtree[chosen_idx]
trees_visited.append(tree)

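        # Backpropagation: credit the rollout result to every node on the visited path.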
for tree in trees_visited:
tree.visits += 1
tree.wins += rollout_reward

def _get_active_agent(self, td: TensorDict) -> str:
turns = torch.stack([td[agent][self.turn_key] for agent in self.agent_keys])
if turns.sum() != 1:
raise ValueError(
"MCTS only supports environments in which it is only one agent's turn at a time."
)
return self.agent_keys[turns.nonzero()]

# TODO: Allow user to specify different priority functions with PR #2358
def _traversal_priority_UCB1(self, tree):
subtree = tree.subtree
visits = subtree.visits
reward_sum = subtree.wins
parent_visits = tree.visits
active_agent = self._get_active_agent(subtree.rollout[0, 0])
reward_sum = reward_sum[active_agent, "reward"].squeeze(-1)

C = 2.0**0.5
        priority = reward_sum / visits + C * torch.sqrt(torch.log(parent_visits) / visits)
priority[visits == 0] = float("inf")
return priority
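
For reference, a worked check of the UCB1 priority above on a toy node (standard UCB1 with C = sqrt(2), matching the code; the numbers are made up):

import torch

wins = torch.tensor([3.0, 1.0])    # summed rewards per child
visits = torch.tensor([4.0, 2.0])  # visit counts per child
parent_visits = visits.sum()       # 6 visits at the parent
C = 2.0**0.5
ucb1 = wins / visits + C * torch.sqrt(torch.log(parent_visits) / visits)
# child 0: 3/4 + sqrt(2 * ln(6) / 4) ≈ 1.70
# child 1: 1/2 + sqrt(2 * ln(6) / 2) ≈ 1.84  -> selected next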