
Commit d009835

Author: Vincent Moens (committed)
[Example] RNN-based policy example
ghstack-source-id: ef0087e Pull Request resolved: #2675
1 parent ab4250e commit d009835

File tree

2 files changed: +207 -2 lines changed


examples/agents/recurrent_actor.py

Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


"""
This code exemplifies how an actor that uses an RNN backbone can be built.

It is based on snippets from the DQN with RNN tutorial.

There are two main APIs to be aware of when using RNNs, and dedicated notes regarding these can be found at the end
of this example: the `set_recurrent_mode` context manager, and the `make_tensordict_primer` method.

"""
from collections import OrderedDict

import torch
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torch import nn

from torchrl.envs import (
    Compose,
    GrayScale,
    GymEnv,
    InitTracker,
    ObservationNorm,
    Resize,
    RewardScaling,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.modules import ConvNet, LSTMModule, MLP, QValueModule, set_recurrent_mode

# Define the device to use for computations (GPU if available, otherwise CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Create a transformed environment using the CartPole-v1 gym environment
env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, device=device),
    # Apply a series of transformations to the environment:
    # 1. Convert observations to tensor images
    # 2. Convert images to grayscale
    # 3. Resize images to 84x84 pixels
    # 4. Keep track of the step count
    # 5. Initialize a tracker for the environment
    # 6. Scale rewards by a factor of 0.1
    # 7. Normalize observations to have zero mean and unit variance (we'll adapt that dynamically later)
    Compose(
        ToTensorImage(),
        GrayScale(),
        Resize(84, 84),
        StepCounter(),
        InitTracker(),
        RewardScaling(loc=0.0, scale=0.1),
        ObservationNorm(standard_normal=True, in_keys=["pixels"]),
    ),
)

# Initialize the normalization statistics for the observation norm transform
env.transform[-1].init_stats(1000, reduce_dim=[0, 1, 2], cat_dim=0, keep_dims=[0])
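
# Optional sanity check (a minimal sketch; the `loc` and `scale` attribute names are an
# assumption, not part of the original example): after init_stats, the ObservationNorm
# transform should hold the statistics it just collected.
print(env.transform[-1].loc.shape, env.transform[-1].scale.shape)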

# Reset the environment to get an initial observation
td = env.reset()

# Define a feature extractor module that takes pixel observations as input
# and outputs an embedding vector
feature = Mod(
    ConvNet(
        num_cells=[32, 32, 64],
        squeeze_output=True,
        aggregator_class=nn.AdaptiveAvgPool2d,
        aggregator_kwargs={"output_size": (1, 1)},
        device=device,
    ),
    in_keys=["pixels"],
    out_keys=["embed"],
)

# Get the shape of the embedding vector output by the feature extractor
with torch.no_grad():
    n_cells = feature(env.reset())["embed"].shape[-1]

# Define an LSTM module that takes the embedding vector as input and outputs
# a new embedding vector
lstm = LSTMModule(
    input_size=n_cells,
    hidden_size=128,
    device=device,
    in_key="embed",
    out_key="embed",
)

# Define a multi-layer perceptron (MLP) module that takes the LSTM output as
# input and outputs action values
mlp = MLP(
    out_features=2,
    num_cells=[
        64,
    ],
    device=device,
)

# Initialize the bias of the last layer of the MLP to zero
mlp[-1].bias.data.fill_(0.0)

# Wrap the MLP in a TensorDictModule to handle input/output keys
mlp = Mod(mlp, in_keys=["embed"], out_keys=["action_value"])

# Define a Q-value module that computes the Q-value of the current state
qval = QValueModule(action_space=None, spec=env.action_spec)

# Add a TensorDictPrimer to the environment to ensure that the policy is aware
# of the supplementary inputs and outputs (recurrent states) during rollout execution
# This is necessary when using batched environments or parallel data collection
env.append_transform(lstm.make_tensordict_primer())
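
# Quick check (a minimal sketch, not prescribed by the original example): after appending the
# primer, the reset tensordict is expected to carry zero-initialized recurrent states
# ("recurrent_state_h" / "recurrent_state_c"), matching the keys visible in the rollout below.
print(env.reset())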

# Create a sequential module that combines the feature extractor, LSTM, MLP, and Q-value modules
policy = Seq(OrderedDict(feature=feature, lstm=lstm, mlp=mlp, qval=qval))

# Roll out the policy in the environment for up to 100 steps
rollout = env.rollout(100, policy)
print(rollout)

# Printed result:
#
# TensorDict(
#     fields={
#         action: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.int64, is_shared=False),
#         action_value: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.float32, is_shared=False),
#         chosen_action_value: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
#         done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
#         embed: Tensor(shape=torch.Size([10, 128]), device=cpu, dtype=torch.float32, is_shared=False),
#         is_init: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
#         next: TensorDict(
#             fields={
#                 done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
#                 is_init: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
#                 pixels: Tensor(shape=torch.Size([10, 1, 84, 84]), device=cpu, dtype=torch.float32, is_shared=False),
#                 recurrent_state_c: Tensor(shape=torch.Size([10, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
#                 recurrent_state_h: Tensor(shape=torch.Size([10, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
#                 reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
#                 step_count: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
#                 terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
#                 truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
#             batch_size=torch.Size([10]),
#             device=cpu,
#             is_shared=False),
#         pixels: Tensor(shape=torch.Size([10, 1, 84, 84]), device=cpu, dtype=torch.float32, is_shared=False),
#         recurrent_state_c: Tensor(shape=torch.Size([10, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
#         recurrent_state_h: Tensor(shape=torch.Size([10, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
#         step_count: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
#         terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
#         truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
#     batch_size=torch.Size([10]),
#     device=cpu,
#     is_shared=False)
#

# Notes:
# 1. make_tensordict_primer
#
# make_tensordict_primer creates a TensorDictPrimer object that ensures the policy is aware
# of the supplementary inputs and outputs (recurrent states) during rollout execution.
# This is necessary when using batched environments or parallel data collection, as the recurrent states
# need to be shared across processes and dealt with properly.
#
# In other words, make_tensordict_primer adds the LSTM's hidden states to the environment's specs,
# allowing the environment to properly handle the recurrent states during rollouts. Without it, the policy
# would not be able to use the LSTM's memory buffers correctly, leading to poorly defined behaviors,
# especially in parallel settings.
#
# By adding the TensorDictPrimer to the environment, you ensure that the policy can correctly use the
# LSTM's recurrent states, even when running in parallel or batched environments. This is why
# env.append_transform(lstm.make_tensordict_primer()) is called before creating the policy and rolling it
# out in the environment.
#
# 2. Using the LSTM to process multiple steps at once.
#
# When set_recurrent_mode("recurrent") is used, the LSTM processes the entire input tensordict as a sequence,
# using its recurrent connections to maintain state across time steps. This mode may use CuDNN to accelerate
# the processing of the sequence on CUDA devices. The behavior in this mode is akin to torch.nn.LSTM, where
# the LSTM expects the input data to be organized in batches of sequences.
#
# On the other hand, when set_recurrent_mode("sequential") is used, the LSTM processes each step in the input
# tensordict independently, without maintaining any state across time steps. This mode makes the LSTM behave
# similarly to torch.nn.LSTMCell, where each input is treated as a separate, independent element.
#
# In the example code, set_recurrent_mode("recurrent") is used to process a tensordict of shape [T], where T
# is the number of steps. This allows the LSTM to use its recurrent connections to maintain state across the
# entire sequence.
#
# In contrast, set_recurrent_mode("sequential") is used to process a single step from the tensordict (i.e.,
# rollout[0]). In this case, the LSTM does not use its recurrent connections and simply processes the single
# step as if it were an independent input.

with set_recurrent_mode("recurrent"):
    # Process a tensordict of shape [T], where T is the number of steps
    print(policy(rollout))

with set_recurrent_mode("sequential"):
    # Process a single step of the rollout; it is treated as an independent input
    print(policy(rollout[0]))

torchrl/modules/tensordict_module/rnn.py

Lines changed: 2 additions & 2 deletions
@@ -1652,8 +1652,8 @@ class set_recurrent_mode(_DecoratorContextManager):
     """Context manager for setting RNNs recurrent mode.

     Args:
-        mode (bool, "recurrent" or "stateful"): the recurrent mode to be used within the context manager.
-            `"recurrent"` leads to `mode=True` and `"stateful"` leads to `mode=False`.
+        mode (bool, "recurrent" or "sequential"): the recurrent mode to be used within the context manager.
+            `"recurrent"` leads to `mode=True` and `"sequential"` leads to `mode=False`.
     An RNN executed with recurrent_mode "on" assumes that the data comes in time batches, otherwise
     it is assumed that each data element in a tensordict is independent of the others.
     The default value of this context manager is ``True``.
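
Per the updated docstring, the string and boolean forms of the mode argument should be interchangeable; a minimal usage sketch (an illustration, not part of this diff):

    from torchrl.modules import set_recurrent_mode

    # "recurrent" maps to mode=True, "sequential" maps to mode=False (per the docstring above)
    with set_recurrent_mode("recurrent"):  # equivalent to set_recurrent_mode(True)
        ...
    with set_recurrent_mode(False):  # equivalent to set_recurrent_mode("sequential")
        ...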
