# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


"""
This example shows how to build an actor that uses an RNN backbone.

It is based on snippets from the DQN with RNN tutorial.

There are two main APIs to be aware of when using RNNs: the ``set_recurrent_mode`` context manager and the
``make_tensordict_primer`` method. Dedicated notes on both can be found at the end of this example.
"""
from collections import OrderedDict

import torch
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torch import nn

from torchrl.envs import (
    Compose,
    GrayScale,
    GymEnv,
    InitTracker,
    ObservationNorm,
    Resize,
    RewardScaling,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.modules import ConvNet, LSTMModule, MLP, QValueModule, set_recurrent_mode

# Define the device to use for computations (GPU if available, otherwise CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Create a transformed environment using the CartPole-v1 gym environment
env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, device=device),
    # Apply a series of transformations to the environment:
    # 1. Convert observations to tensor images
    # 2. Convert images to grayscale
    # 3. Resize images to 84x84 pixels
    # 4. Keep track of the step count
    # 5. Track which steps follow a reset (adds the "is_init" flag needed by the LSTM)
    # 6. Scale rewards by a factor of 0.1
    # 7. Normalize observations to have zero mean and unit variance (statistics are initialized from data just below)
    Compose(
        ToTensorImage(),
        GrayScale(),
        Resize(84, 84),
        StepCounter(),
        InitTracker(),
        RewardScaling(loc=0.0, scale=0.1),
        ObservationNorm(standard_normal=True, in_keys=["pixels"]),
    ),
)

# Initialize the normalization statistics for the observation norm transform
env.transform[-1].init_stats(1000, reduce_dim=[0, 1, 2], cat_dim=0, keep_dims=[0])
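
# Optional sanity check (a sketch, not part of the original example): after
# init_stats, normalized pixel observations should be roughly standardized.
with torch.no_grad():
    _check = env.rollout(10)
    print("pixel mean/std:", _check["pixels"].mean().item(), _check["pixels"].std().item())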

# Reset the environment to get an initial observation
td = env.reset()

# Define a feature extractor module that takes pixel observations as input
# and outputs an embedding vector
feature = Mod(
    ConvNet(
        num_cells=[32, 32, 64],
        squeeze_output=True,
        aggregator_class=nn.AdaptiveAvgPool2d,
        aggregator_kwargs={"output_size": (1, 1)},
        device=device,
    ),
    in_keys=["pixels"],
    out_keys=["embed"],
)

# Get the size of the embedding vector output by the feature extractor
with torch.no_grad():
    n_cells = feature(env.reset())["embed"].shape[-1]

# Define an LSTM module that takes the embedding vector as input and outputs
# a new embedding vector
lstm = LSTMModule(
    input_size=n_cells,
    hidden_size=128,
    device=device,
    in_key="embed",
    out_key="embed",
)
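
# The LSTMModule threads its hidden states through dedicated tensordict keys
# alongside the "embed" entry; printing the keys makes that data flow explicit
# (a quick inspection, not part of the original example).
print("LSTM in_keys:", lstm.in_keys)
print("LSTM out_keys:", lstm.out_keys)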

# Define a multi-layer perceptron (MLP) module that takes the LSTM output as
# input and outputs action values
mlp = MLP(
    out_features=2,
    num_cells=[64],
    device=device,
)

# Initialize the bias of the last layer of the MLP to zero
mlp[-1].bias.data.fill_(0.0)

# Wrap the MLP in a TensorDictModule to handle input/output keys
mlp = Mod(mlp, in_keys=["embed"], out_keys=["action_value"])

# Define a Q-value module that picks the greedy action from the action values
# (and writes the corresponding chosen_action_value entry)
qval = QValueModule(action_space=None, spec=env.action_spec)

# Add a TensorDictPrimer to the environment to ensure that the policy is aware
# of the supplementary inputs and outputs (recurrent states) during rollout execution.
# This is necessary when using batched environments or parallel data collection.
env.append_transform(lstm.make_tensordict_primer())
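
# After appending the primer, reset tensordicts carry zero-initialized
# recurrent states under the default key names, so the policy finds them on
# the very first step (a quick check, not part of the original example).
td_check = env.reset()
assert "recurrent_state_h" in td_check.keys() and "recurrent_state_c" in td_check.keys()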

# Create a sequential module that combines the feature extractor, LSTM, MLP, and Q-value modules
policy = Seq(OrderedDict(feature=feature, lstm=lstm, mlp=mlp, qval=qval))
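
# The policy can be sanity-checked on a single reset step before rolling it
# out (an optional check, not part of the original example).
with torch.no_grad():
    print(policy(env.reset()))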

# Roll out the policy in the environment for up to 100 steps (the rollout
# stops early if the episode terminates)
rollout = env.rollout(100, policy)
print(rollout)

# Example output (this trajectory terminated after 10 steps):
#
# TensorDict(
#     fields={
#         action: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.int64, is_shared=False),
#         action_value: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.float32, is_shared=False),
#         chosen_action_value: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
#         done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
#         embed: Tensor(shape=torch.Size([10, 128]), device=cpu, dtype=torch.float32, is_shared=False),
#         is_init: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
#         next: TensorDict(
#             fields={
#                 done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
#                 is_init: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
#                 pixels: Tensor(shape=torch.Size([10, 1, 84, 84]), device=cpu, dtype=torch.float32, is_shared=False),
#                 recurrent_state_c: Tensor(shape=torch.Size([10, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
#                 recurrent_state_h: Tensor(shape=torch.Size([10, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
#                 reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
#                 step_count: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
#                 terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
#                 truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
#             batch_size=torch.Size([10]),
#             device=cpu,
#             is_shared=False),
#         pixels: Tensor(shape=torch.Size([10, 1, 84, 84]), device=cpu, dtype=torch.float32, is_shared=False),
#         recurrent_state_c: Tensor(shape=torch.Size([10, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
#         recurrent_state_h: Tensor(shape=torch.Size([10, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
#         step_count: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
#         terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
#         truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
#     batch_size=torch.Size([10]),
#     device=cpu,
#     is_shared=False)
#

# Notes:
# 1. make_tensordict_primer
#
# make_tensordict_primer creates a TensorDictPrimer transform that makes the environment aware of the
# supplementary inputs and outputs (the recurrent states) that the policy reads and writes during
# rollout execution. This is necessary when using batched environments or parallel data collection,
# as the recurrent states need to be shared across processes and dealt with properly.
#
# In other words, make_tensordict_primer adds the LSTM's hidden states to the environment's specs,
# allowing the environment to properly handle the recurrent states during rollouts. Without it, the
# policy would not be able to use the LSTM's memory buffers correctly, leading to poorly defined
# behaviors, especially in parallel settings.
#
# By adding the TensorDictPrimer to the environment, you ensure that the policy can correctly use the
# LSTM's recurrent states, even when running in parallel or batched environments. This is why
# env.append_transform(lstm.make_tensordict_primer()) is called before creating the policy and rolling
# it out in the environment.
#
# 2. Using the LSTM to process multiple steps at once.
#
# When set_recurrent_mode("recurrent") is used, the LSTM processes the entire input tensordict as a
# sequence, using its recurrent connections to maintain state across time steps. This mode may use
# cuDNN to accelerate the processing of the sequence on CUDA devices. The behavior in this mode is
# akin to torch.nn.LSTM, where the LSTM expects the input data to be organized in batches of
# sequences.
#
# On the other hand, when set_recurrent_mode("sequential") is used, the LSTM processes each step in
# the input tensordict on its own, without unrolling over a time dimension. This mode makes the LSTM
# behave similarly to torch.nn.LSTMCell, where each input is treated as a separate, independent
# element.
#
# In the example below, set_recurrent_mode("recurrent") is used to process a tensordict of shape [T],
# where T is the number of steps: the LSTM uses its recurrent connections to maintain state across
# the entire sequence.
#
# In contrast, set_recurrent_mode("sequential") is used to process a single step from the tensordict
# (i.e., rollout[0]). In this case, the LSTM does not unroll over time and simply processes the
# single step as if it were an independent input.

with set_recurrent_mode("recurrent"):
    # Process a whole trajectory of shape [T], where T is the number of steps
    print(policy(rollout))

with set_recurrent_mode("sequential"):
    # Process a single step taken from the trajectory
    print(policy(rollout[0]))
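
# Consistency check (a sketch, not part of the original example): re-running
# the whole trajectory in recurrent mode should reproduce the action values
# that were computed step by step during collection, up to numerical precision.
with set_recurrent_mode("recurrent"), torch.no_grad():
    action_values = policy(rollout.clone())["action_value"]
print(torch.allclose(action_values, rollout["action_value"], atol=1e-4))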