Commit e1ed61e

[Feature,Example] Add MCTS algorithm and example

ghstack-source-id: 6230e05
Pull Request resolved: #2796

1 parent: 8edc29c

File tree

5 files changed: +322 -2 lines changed


examples/trees/mcts.py

Lines changed: 161 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time

import torch
import torchrl
import torchrl.envs
import torchrl.modules.mcts
from tensordict import TensorDict
from torchrl.data import Composite, Unbounded
from torchrl.envs import Transform

pgn_or_fen = "fen"
mask_actions = True

env = torchrl.envs.ChessEnv(
    include_pgn=False,
    include_fen=True,
    include_hash=True,
    include_hash_inv=True,
    include_san=True,
    stateful=True,
    mask_actions=mask_actions,
)


class TurnBasedChess(Transform):
    def transform_observation_spec(self, obsspec):
        obsspec["agent0", "turn"] = Unbounded(dtype=torch.bool, shape=())
        obsspec["agent1", "turn"] = Unbounded(dtype=torch.bool, shape=())
        return obsspec

    def transform_reward_spec(self, reward_spec):
        reward = reward_spec["reward"].clone()
        del reward_spec["reward"]
        return Composite(
            agent0=Composite(reward=reward),
            agent1=Composite(reward=reward),
        )

    def _reset(self, _td, td):
        td["agent0", "turn"] = td["turn"]
        td["agent1", "turn"] = ~td["turn"]
        return td

    def _step(self, td, td_next):
        td_next["agent0", "turn"] = td_next["turn"]
        td_next["agent1", "turn"] = ~td_next["turn"]

        reward = td_next["reward"]
        turn = td["turn"]

        if reward == 0.5:
            reward = 0
        elif reward == 1:
            if not turn:
                reward = -reward

        td_next["agent0", "reward"] = reward
        td_next["agent1", "reward"] = -reward
        del td_next["reward"]

        return td_next


env = env.append_transform(TurnBasedChess())
env.rollout(3)

forest = torchrl.data.MCTSForest()
forest.reward_keys = env.reward_keys
forest.done_keys = env.done_keys
forest.action_keys = env.action_keys

if mask_actions:
    forest.observation_keys = [
        f"{pgn_or_fen}_hash",
        "turn",
        "action_mask",
        ("agent0", "turn"),
        ("agent1", "turn"),
    ]
else:
    forest.observation_keys = [
        f"{pgn_or_fen}_hash",
        "turn",
        ("agent0", "turn"),
        ("agent1", "turn"),
    ]


def tree_format_fn(tree):
    td = tree.rollout[-1]["next"]
    return [
        td["san"],
        td[pgn_or_fen].split("\n")[-1],
        tree.wins,
        tree.visits,
    ]


def get_best_move(fen, mcts_steps, rollout_steps):
    root = env.reset(TensorDict({"fen": fen}))
    agent_keys = ["agent0", "agent1"]
    mcts = torchrl.modules.mcts.MCTS(mcts_steps, rollout_steps, agent_keys=agent_keys)
    tree = mcts(forest, root, env)
    moves = []

    for subtree in tree.subtree:
        td = subtree.rollout[0]
        san = td["next", "san"]
        active_agent = agent_keys[
            torch.stack([td[agent]["turn"] for agent in agent_keys]).nonzero()
        ]
        reward_sum = subtree.wins[active_agent, "reward"]
        visits = subtree.visits
        value_avg = (reward_sum / visits).item()
        moves.append((value_avg, san))

    moves = sorted(moves, key=lambda x: -x[0])

    # print(tree.to_string(tree_format_fn))

    print("------------------")
    for value_avg, san in moves:
        print(f" {value_avg:0.02f} {san}")
    print("------------------")

    return moves[0][1]


for idx in range(3):
    print("==========")
    print(idx)
    print("==========")
    torch.manual_seed(idx)

    start_time = time.time()

    # White has M1, best move Rd8#. Any other moves lose to M2 or M1.
    fen0 = "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1"
    assert get_best_move(fen0, 40, 10) == "Rd8#"

    # Black has M1, best move Qg6#. Other moves give rough equality or worse.
    fen1 = "6qk/2R4p/7K/8/8/8/8/4R3 b - - 1 1"
    assert get_best_move(fen1, 40, 10) == "Qg6#"

    # White has M2, best move Rxg8+. Any other move loses.
    fen2 = "2R3qk/5p1p/7K/8/8/8/5r2/2R5 w - - 0 1"
    assert get_best_move(fen2, 600, 10) == "Rxg8+"

    # Black has M2, best move Rxg1+. Any other move loses.
    fen3 = "2r5/5R2/8/8/8/7k/5P1P/2r3QK b - - 0 1"
    assert get_best_move(fen3, 600, 10) == "Rxg1+"

    end_time = time.time()
    total_time = end_time - start_time

    print(f"Took {total_time} s")

torchrl/data/map/tree.py

Lines changed: 5 additions & 0 deletions
@@ -1363,6 +1363,11 @@ def valid_paths(cls, tree: Tree):
     def __len__(self):
         return len(self.data_map)
 
+    def __contains__(self, root: TensorDictBase):
+        if self.node_map is None:
+            return False
+        return root.select(*self.node_map.in_keys) in self.node_map
+
     def to_string(self, td_root, node_format_fn=lambda tree: tree.node_data.to_dict()):
         """Generates a string representation of a tree in the forest.
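
The new `__contains__` makes membership tests against the forest cheap to express. A small usage sketch, assuming `env`, `forest`, and `root` are set up as in the example script above (hypothetical usage, not from this commit):

# Seed the forest with the root's children only if this root is new.
if root not in forest:
    for action in env.all_actions(root):
        td = env.step(env.reset(root.clone()).update(action))
        forest.extend(td.unsqueeze(0))
tree = forest.get_tree(root)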

torchrl/envs/custom/chess.py

Lines changed: 7 additions & 2 deletions
@@ -220,12 +220,15 @@ def lib(cls):
         return chess
 
     _san_moves = []
+    _san_move_to_index_map = {}
 
     @_classproperty
     def san_moves(cls):
         if not cls._san_moves:
             with open(pathlib.Path(__file__).parent / "san_moves.txt", "r+") as f:
                 cls._san_moves.extend(f.read().split("\n"))
+            for idx, san_move in enumerate(cls._san_moves):
+                cls._san_move_to_index_map[san_move] = idx
         return cls._san_moves
 
     def _legal_moves_to_index(
@@ -253,7 +256,7 @@ def _legal_moves_to_index(
         board = self.board
 
         indices = torch.tensor(
-            [self._san_moves.index(board.san(m)) for m in board.legal_moves],
+            [self._san_move_to_index_map[board.san(m)] for m in board.legal_moves],
            dtype=torch.int64,
         )
         mask = None
@@ -407,7 +410,9 @@ def _reset(self, tensordict=None):
             if move is None:
                 dest.set("san", "<start>")
             else:
-                dest.set("san", self.board.san(move))
+                prev_board = self.board.copy()
+                prev_board.pop()
+                dest.set("san", prev_board.san(move))
         if self.include_fen:
             dest.set("fen", fen)
         if self.include_pgn:
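
Two things change here: `_san_move_to_index_map` turns the per-move `list.index` scan into a dict lookup, and `_reset` now rewinds the board before rendering SAN, since in python-chess `Board.san(move)` must be called on the position in which the move is about to be played, not the position after it. A standalone sketch of the rewind pattern, using the `chess` package directly (illustrative, not part of the commit):

import chess

board = chess.Board()
move = chess.Move.from_uci("e2e4")
board.push(move)

# SAN must be rendered from the pre-move position, so rewind a copy first.
prev_board = board.copy()
prev_board.pop()
print(prev_board.san(move))  # "e4"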

torchrl/modules/mcts/__init__.py

Lines changed: 6 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .mcts import MCTS

torchrl/modules/mcts/mcts.py

Lines changed: 143 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Sequence

import torch
import torchrl
from tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey
from torch import nn

from torchrl.data.map import MCTSForest, Tree
from torchrl.envs import EnvBase


class MCTS(nn.Module):
    """Monte-Carlo tree search.

    Attributes:
        num_traversals (int): Number of times to traverse the tree.
        rollout_max_steps (int): Maximum number of steps for each rollout.

    Methods:
        forward: Runs the tree search.
    """

    def __init__(
        self,
        num_traversals: int,
        rollout_max_steps: int | None = None,
        agent_keys: Sequence[NestedKey] = None,
        turn_key: NestedKey = ("turn",),
    ):
        super().__init__()
        self.num_traversals = num_traversals
        self.rollout_max_steps = rollout_max_steps
        self.agent_keys = agent_keys
        self.turn_key = turn_key

    def forward(
        self,
        forest: MCTSForest,
        root: TensorDictBase,
        env: EnvBase,
    ) -> Tree:
        """Performs Monte-Carlo tree search in an environment.

        Args:
            forest (MCTSForest): Forest of the tree to update. If the tree does not
                exist yet, it is added.
            root (TensorDict): The root step of the tree to update.
            env (EnvBase): Environment to perform actions in.
        """
        for action in env.all_actions(root):
            td = env.step(env.reset(root.clone()).update(action))
            forest.extend(td.unsqueeze(0))

        tree = forest.get_tree(root)

        tree.wins = env.reward_spec.zero()

        for subtree in tree.subtree:
            subtree.wins = env.reward_spec.zero()

        for _ in range(self.num_traversals):
            self._traverse_MCTS_one_step(forest, tree, env, self.rollout_max_steps)

        return tree

    def _traverse_MCTS_one_step(self, forest, tree, env, rollout_max_steps):
        done = False
        trees_visited = [tree]

        while not done:
            if tree.subtree is None:
                td_tree = tree.rollout[-1]["next"].clone()

                if (tree.visits > 0 or tree.parent is None) and not td_tree["done"]:
                    actions = env.all_actions(td_tree)
                    subtrees = []

                    for action in actions:
                        td = env.step(env.reset(td_tree).update(action))
                        new_node = torchrl.data.Tree(
                            rollout=td.unsqueeze(0),
                            node_data=td["next"].select(*forest.node_map.in_keys),
                            count=torch.tensor(0),
                            wins=env.reward_spec.zero(),
                        )
                        subtrees.append(new_node)

                    # NOTE: This whole script runs about 2x faster with lazy stack
                    # versus eager stack.
                    tree.subtree = TensorDict.lazy_stack(subtrees)
                    chosen_idx = torch.randint(0, len(subtrees), ()).item()
                    rollout_state = subtrees[chosen_idx].rollout[-1]["next"]

                else:
                    rollout_state = td_tree

                if rollout_state["done"]:
                    rollout_reward = rollout_state.select(*env.reward_keys)
                else:
                    rollout = env.rollout(
                        max_steps=rollout_max_steps,
                        tensordict=rollout_state,
                    )
                    rollout_reward = rollout[-1]["next"].select(*env.reward_keys)
                done = True

            else:
                priorities = self._traversal_priority_UCB1(tree)
                chosen_idx = torch.argmax(priorities).item()
                tree = tree.subtree[chosen_idx]
                trees_visited.append(tree)

        for tree in trees_visited:
            tree.visits += 1
            tree.wins += rollout_reward

    def _get_active_agent(self, td: TensorDict) -> str:
        turns = torch.stack([td[agent][self.turn_key] for agent in self.agent_keys])
        if turns.sum() != 1:
            raise ValueError(
                "MCTS only supports environments in which it is only one agent's turn at a time."
            )
        return self.agent_keys[turns.nonzero()]

    # TODO: Allow user to specify different priority functions with PR #2358
    def _traversal_priority_UCB1(self, tree):
        subtree = tree.subtree
        visits = subtree.visits
        reward_sum = subtree.wins
        parent_visits = tree.visits
        active_agent = self._get_active_agent(subtree.rollout[0, 0])
        reward_sum = reward_sum[active_agent, "reward"].squeeze(-1)

        C = 2.0**0.5
        priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits
        priority[visits == 0] = float("inf")
        return priority
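
For reference, `_traversal_priority_UCB1` above divides the whole sum `reward_sum + C * sqrt(ln(parent_visits))` by the child's visit count, whereas the textbook UCB1 score adds the exploration bonus to the mean reward. A minimal sketch of the textbook form in the same tensor style, for comparison only (not part of the commit):

import torch

def ucb1_reference(wins, visits, parent_visits, c=2.0**0.5):
    # Textbook UCB1: mean reward plus an exploration bonus that shrinks with visits.
    score = wins / visits + c * torch.sqrt(torch.log(parent_visits) / visits)
    score[visits == 0] = float("inf")  # expand unvisited children first
    return score

Both variants assign unvisited children infinite priority, so every child is expanded at least once before exploitation takes over.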
