Commit c1c3fb7

[Feature,Example] Add MCTS algorithm and example
ghstack-source-id: 220d7ed
Pull Request resolved: #2796
1 parent a31dca3 commit c1c3fb7

File tree

5 files changed: +279 -2 lines changed

examples/trees/mcts.py
torchrl/data/map/tree.py
torchrl/envs/custom/chess.py
torchrl/modules/mcts/__init__.py
torchrl/modules/mcts/mcts.py

examples/trees/mcts.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time

import torch
import torchrl
import torchrl.envs
import torchrl.modules.mcts
from tensordict import TensorDict

pgn_or_fen = "fen"
mask_actions = True

env = torchrl.envs.ChessEnv(
    include_pgn=False,
    include_fen=True,
    include_hash=True,
    include_hash_inv=True,
    include_san=True,
    stateful=True,
    mask_actions=mask_actions,
)


class TransformReward:
    def __call__(self, td):
        if "reward" not in td:
            return td

        reward = td["reward"]

        if reward == 0.5:
            reward = 0
        elif reward == 1 and td["turn"]:
            reward = -reward

        td["reward"] = reward
        return td


# ChessEnv sets the reward to 0.5 for a draw and 1 for a win for either player.
# Need to transform the reward to be:
#   white win = 1
#   draw = 0
#   black win = -1
transform_reward = TransformReward()
env = env.append_transform(transform_reward)

forest = torchrl.data.MCTSForest()
forest.reward_keys = env.reward_keys
forest.done_keys = env.done_keys
forest.action_keys = env.action_keys

if mask_actions:
    forest.observation_keys = [f"{pgn_or_fen}_hash", "turn", "action_mask"]
else:
    forest.observation_keys = [f"{pgn_or_fen}_hash", "turn"]


def tree_format_fn(tree):
    td = tree.rollout[-1]["next"]
    return [
        td["san"],
        td[pgn_or_fen].split("\n")[-1],
        tree.wins,
        tree.visits,
    ]


def get_best_move(fen, mcts_steps, rollout_steps):
    root = env.reset(TensorDict({"fen": fen}))
    mcts = torchrl.modules.mcts.MCTS(mcts_steps, rollout_steps)
    tree = mcts(forest, root, env)
    moves = []

    for subtree in tree.subtree:
        san = subtree.rollout[0]["next", "san"]
        reward_sum = subtree.wins
        visits = subtree.visits
        value_avg = (reward_sum / visits).item()
        if not root["turn"]:
            value_avg = -value_avg
        moves.append((value_avg, san))

    moves = sorted(moves, key=lambda x: -x[0])

    # print(tree.to_string(tree_format_fn))

    print("------------------")
    for value_avg, san in moves:
        print(f"    {value_avg:0.02f} {san}")
    print("------------------")

    return moves[0][1]


for idx in range(30):
    print("==========")
    print(idx)
    print("==========")
    torch.manual_seed(idx)

    start_time = time.time()

    # White has M1, best move Rd8#. Any other moves lose to M2 or M1.
    fen0 = "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1"
    assert get_best_move(fen0, 40, 10) == "Rd8#"

    # Black has M1, best move Qg6#. Other moves give rough equality or worse.
    fen1 = "6qk/2R4p/7K/8/8/8/8/4R3 b - - 1 1"
    assert get_best_move(fen1, 40, 10) == "Qg6#"

    # White has M2, best move Rxg8+. Any other move loses.
    fen2 = "2R3qk/5p1p/7K/8/8/8/5r2/2R5 w - - 0 1"
    assert get_best_move(fen2, 600, 10) == "Rxg8+"

    # Black has M2, best move Rxg1+. Any other move loses.
    fen3 = "2r5/5R2/8/8/8/7k/5P1P/2r3QK b - - 0 1"
    assert get_best_move(fen3, 600, 10) == "Rxg1+"

    end_time = time.time()
    total_time = end_time - start_time

    print(f"Took {total_time} s")

torchrl/data/map/tree.py

Lines changed: 5 additions & 0 deletions
@@ -1363,6 +1363,11 @@ def valid_paths(cls, tree: Tree):
     def __len__(self):
         return len(self.data_map)
 
+    def __contains__(self, root: TensorDictBase):
+        if self.node_map is None:
+            return False
+        return root.select(*self.node_map.in_keys) in self.node_map
+
     def to_string(self, td_root, node_format_fn=lambda tree: tree.node_data.to_dict()):
         """Generates a string representation of a tree in the forest.
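The new __contains__ lets callers check with Python's `in` operator whether a root state is already stored in a forest before calling get_tree. A minimal usage sketch, reusing the env/forest setup from examples/trees/mcts.py above (illustrative only, not part of the diff):

# Assumes `env`, `forest` and the fen string are configured as in the example script.
root = env.reset(TensorDict({"fen": "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1"}))
print(root in forest)  # False while the forest is empty (node_map is still None)

for action in env.all_actions(root):
    td = env.step(env.reset(root.clone()).update(action))
    forest.extend(td.unsqueeze(0))

print(root in forest)  # expected True once transitions out of the root are stored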

torchrl/envs/custom/chess.py

Lines changed: 7 additions & 2 deletions
@@ -222,12 +222,15 @@ def lib(cls):
         return chess
 
     _san_moves = []
+    _san_move_to_index_map = {}
 
     @_classproperty
     def san_moves(cls):
         if not cls._san_moves:
             with open(pathlib.Path(__file__).parent / "san_moves.txt", "r+") as f:
                 cls._san_moves.extend(f.read().split("\n"))
+            for idx, san_move in enumerate(cls._san_moves):
+                cls._san_move_to_index_map[san_move] = idx
         return cls._san_moves
 
     def _legal_moves_to_index(
@@ -255,7 +258,7 @@ def _legal_moves_to_index(
         board = self.board
 
         indices = torch.tensor(
-            [self._san_moves.index(board.san(m)) for m in board.legal_moves],
+            [self._san_move_to_index_map[board.san(m)] for m in board.legal_moves],
            dtype=torch.int64,
        )
        mask = None
@@ -409,7 +412,9 @@ def _reset(self, tensordict=None):
        if move is None:
            dest.set("san", "<start>")
        else:
-            dest.set("san", self.board.san(move))
+            prev_board = self.board.copy()
+            prev_board.pop()
+            dest.set("san", prev_board.san(move))
        if self.include_fen:
            dest.set("fen", fen)
        if self.include_pgn:
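The first two hunks replace a linear list.index scan with a precomputed dict, so SAN-to-index lookup becomes a constant-time operation per legal move. The _reset hunk is needed because python-chess derives SAN from the position before a move is played; once the move is already on the board, the SAN of the last move has to be computed on a copy with that move popped off. A small standalone sketch of that behaviour (illustrative, not part of the diff):

import chess  # python-chess

board = chess.Board()
move = chess.Move.from_uci("e2e4")
san = board.san(move)  # "e4": SAN is defined relative to the pre-move position
board.push(move)

# The move is now on the board, so its SAN can no longer be taken from `board`
# directly. Recover the previous position the way the patched _reset does:
# copy the board and pop the last move.
prev_board = board.copy()
last_move = prev_board.pop()  # undoes the move and returns it
assert prev_board.san(last_move) == san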

torchrl/modules/mcts/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .mcts import MCTS

torchrl/modules/mcts/mcts.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torchrl
from tensordict import TensorDict, TensorDictBase
from tensordict.nn import TensorDictModuleBase

from torchrl.data.map import MCTSForest, Tree
from torchrl.envs import EnvBase

C = 2.0**0.5


class MCTS(TensorDictModuleBase):
    """Monte-Carlo tree search.

    Attributes:
        num_traversals (int): Number of times to traverse the tree.
        rollout_max_steps (int): Maximum number of steps for each rollout.

    Methods:
        forward: Runs the tree search.
    """

    def __init__(
        self,
        num_traversals: int,
        rollout_max_steps: int | None = None,
    ):
        super().__init__()
        self.num_traversals = num_traversals
        self.rollout_max_steps = rollout_max_steps

    def forward(
        self,
        forest: MCTSForest,
        root: TensorDictBase,
        env: EnvBase,
    ) -> Tree:
        """Performs Monte-Carlo tree search in an environment.

        Args:
            forest (MCTSForest): Forest of the tree to update. If the tree does not
                exist yet, it is added.
            root (TensorDict): The root step of the tree to update.
            env (EnvBase): Environment to perform actions in.
        """
        for action in env.all_actions(root):
            td = env.step(env.reset(root.clone()).update(action))
            forest.extend(td.unsqueeze(0))

        tree = forest.get_tree(root)

        tree.wins = torch.zeros_like(td["next", env.reward_key])
        for subtree in tree.subtree:
            subtree.wins = torch.zeros_like(td["next", env.reward_key])

        for _ in range(self.num_traversals):
            self._traverse_MCTS_one_step(forest, tree, env, self.rollout_max_steps)

        return tree

    def _traverse_MCTS_one_step(self, forest, tree, env, rollout_max_steps):
        done = False
        trees_visited = [tree]

        while not done:
            if tree.subtree is None:
                td_tree = tree.rollout[-1]["next"].clone()

                if (tree.visits > 0 or tree.parent is None) and not td_tree["done"]:
                    actions = env.all_actions(td_tree)
                    subtrees = []

                    for action in actions:
                        td = env.step(env.reset(td_tree).update(action))
                        new_node = torchrl.data.Tree(
                            rollout=td.unsqueeze(0),
                            node_data=td["next"].select(*forest.node_map.in_keys),
                            count=torch.tensor(0),
                            wins=torch.zeros_like(td["next", env.reward_key]),
                        )
                        subtrees.append(new_node)

                    # NOTE: This whole script runs about 2x faster with lazy stack
                    # versus eager stack.
                    tree.subtree = TensorDict.lazy_stack(subtrees)
                    chosen_idx = torch.randint(0, len(subtrees), ()).item()
                    rollout_state = subtrees[chosen_idx].rollout[-1]["next"]

                else:
                    rollout_state = td_tree

                if rollout_state["done"]:
                    rollout_reward = rollout_state[env.reward_key]
                else:
                    rollout = env.rollout(
                        max_steps=rollout_max_steps,
                        tensordict=rollout_state,
                    )
                    rollout_reward = rollout[-1]["next", env.reward_key]
                done = True

            else:
                priorities = self._traversal_priority_UCB1(tree)
                chosen_idx = torch.argmax(priorities).item()
                tree = tree.subtree[chosen_idx]
                trees_visited.append(tree)

        for tree in trees_visited:
            tree.visits += 1
            tree.wins += rollout_reward

    # TODO: Allow user to specify different priority functions with PR #2358
    def _traversal_priority_UCB1(self, tree):
        subtree = tree.subtree
        visits = subtree.visits
        reward_sum = subtree.wins

        # If it's black's turn, flip the reward, since black wants to optimize for
        # the lowest reward, not highest.
        # TODO: Need a more generic way to do this, since not all use cases of MCTS
        # will be two player turn based games.
        if not subtree.rollout[0, 0]["turn"]:
            reward_sum = -reward_sum

        parent_visits = tree.visits
        reward_sum = reward_sum.squeeze(-1)
        priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits
        priority[visits == 0] = float("inf")
        return priority
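Reading the module above in standard MCTS terms: forward first expands the root with every legal action, and each call to _traverse_MCTS_one_step then performs selection (descend into the child with the highest priority), expansion (a leaf that has already been visited, or the root, gets all of its children added), simulation (a rollout of at most rollout_max_steps steps starting from a randomly chosen new child, or from the leaf itself, and skipped when that state is already terminal), and backpropagation (visit counts and win sums are updated along the visited path). As I read _traversal_priority_UCB1, the priority assigned to child i of a node with N visits, where W_i is the turn-adjusted win sum and n_i the visit count of child i, is

\[
\mathrm{priority}_i \;=\; \frac{W_i + C\,\sqrt{\ln N}}{n_i},
\qquad C = \sqrt{2},
\qquad \mathrm{priority}_i = \infty \ \text{when } n_i = 0 .
\]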
