
Commit 6d4ea6d

[Algorithm] TD3 fast
ghstack-source-id: 6385b70 Pull Request resolved: #2389
1 parent 6a1ec81 commit 6d4ea6d

4 files changed (+381 -33 lines)

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
# task and env
env:
  name: HalfCheetah-v4  # Use v4 to get rid of mujoco-py dependency
  task: ""
  library: gymnasium
  seed: 42
  max_episode_steps: 1000

# collector
collector:
  total_frames: 1000000
  init_random_frames: 25_000
  init_env_steps: 1000
  frames_per_batch: 1000
  reset_at_each_iter: False
  device: cpu
  env_per_collector: 1
  num_workers: 8

# replay buffer
replay_buffer:
  prb: 0  # use prioritized experience replay
  size: 1000000
  scratch_dir: null
  device: null

# optim
optim:
  utd_ratio: 1.0
  gamma: 0.99
  loss_function: l2
  lr: 3.0e-4
  weight_decay: 0.0
  adam_eps: 1e-4
  batch_size: 256
  target_update_polyak: 0.995
  policy_update_delay: 2
  policy_noise: 0.2
  noise_clip: 0.5

# network
network:
  hidden_sizes: [256, 256]
  activation: relu
  device: null

# logging
logger:
  backend: wandb
  project_name: torchrl_example_td3
  group_name: null
  exp_name: ${env.name}_TD3
  mode: online
  eval_iter: 25000
  video: False
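
This new configuration (presumably saved as config-fast.yaml, to match config_name="config-fast" in the script below) is composed by Hydra when the training script starts, and its values can be overridden on the command line with Hydra's usual key=value syntax. The sketch below is not part of the commit; it only illustrates how the ${env.name} interpolation in logger.exp_name is resolved by OmegaConf, the configuration library Hydra builds on:

    from omegaconf import OmegaConf

    # Trimmed stand-in for the config above, just enough to show interpolation.
    cfg = OmegaConf.create(
        {"env": {"name": "HalfCheetah-v4"}, "logger": {"exp_name": "${env.name}_TD3"}}
    )
    assert cfg.logger.exp_name == "HalfCheetah-v4_TD3"  # resolved on access

With the values above, the script performs max(1, env_per_collector) * frames_per_batch * utd_ratio = 1 * 1000 * 1.0 = 1000 gradient updates per 1000 collected frames, and updates the actor and target networks on every second update (policy_update_delay: 2). Leaving network.device and replay_buffer.device as null lets the script fall back to cuda:0 when a GPU is available, otherwise cpu.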

sota-implementations/td3/td3-fast.py

Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""TD3 Example.

This is a simple self-contained example of a TD3 training script.

It supports state environments like MuJoCo.

The helper functions are coded in the utils.py associated with this script.
"""
import time

import hydra
import numpy as np
import torch
import torch.cuda
import tqdm
from torchrl._utils import logger as torchrl_logger
from torchrl.data.utils import CloudpickleWrapper

from torchrl.envs.utils import ExplorationType, set_exploration_type

from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
    log_metrics,
    make_async_collector,
    make_environment,
    make_loss_module,
    make_optimizer,
    make_replay_buffer,
    make_simple_environment,
    make_td3_agent,
)


@hydra.main(version_base="1.1", config_path="", config_name="config-fast")
def main(cfg: "DictConfig"):  # noqa: F821
    device = cfg.network.device
    if device in ("", None):
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            device = torch.device("cpu")
    device = torch.device(device)

    # Create logger
    exp_name = generate_exp_name("TD3", cfg.logger.exp_name)
    logger = None
    if cfg.logger.backend:
        logger = get_logger(
            logger_type=cfg.logger.backend,
            logger_name="td3_logging",
            experiment_name=exp_name,
            wandb_kwargs={
                "mode": cfg.logger.mode,
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    # Set seeds
    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)

    # Create environments
    train_env, eval_env = make_environment(cfg, logger=logger)

    # Create agent
    model, exploration_policy = make_td3_agent(cfg, train_env, eval_env, device)

    # Create TD3 loss
    loss_module, target_net_updater = make_loss_module(cfg, model)

    # Create replay buffer
    replay_buffer = make_replay_buffer(
        batch_size=cfg.optim.batch_size,
        prb=cfg.replay_buffer.prb,
        buffer_size=cfg.replay_buffer.size,
        scratch_dir=cfg.replay_buffer.scratch_dir,
        device=cfg.replay_buffer.device if cfg.replay_buffer.device else device,
        prefetch=0,
        mmap=False,
    )
    reshape = CloudpickleWrapper(lambda td: td.reshape(-1))
    replay_buffer.append_transform(reshape, invert=True)

    # Create off-policy collector
    envname = cfg.env.name
    task = cfg.env.task
    library = cfg.env.library
    seed = cfg.env.seed
    max_episode_steps = cfg.env.max_episode_steps
    collector = make_async_collector(
        cfg,
        lambda: make_simple_environment(
            envname, task, library, seed, max_episode_steps
        ),
        exploration_policy,
        replay_buffer,
    )

    # Create optimizers
    optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)

    # Main loop
    start_time = time.time()
    collected_frames = 0
    pbar = tqdm.tqdm(total=cfg.collector.total_frames)

    init_random_frames = cfg.collector.init_random_frames
    num_updates = int(
        max(1, cfg.collector.env_per_collector)
        * cfg.collector.frames_per_batch
        * cfg.optim.utd_ratio
    )
    delayed_updates = cfg.optim.policy_update_delay
    prb = cfg.replay_buffer.prb
    update_counter = 0

    sampling_start = time.time()
    current_frames = cfg.collector.frames_per_batch
    update_actor = False

    test_env = make_simple_environment(envname, task, library, seed, max_episode_steps)
    with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
        reward = test_env.rollout(10_000, exploration_policy)["next", "reward"].mean()
    print(f"reward before training: {reward: 4.4f}")

    # loss_module.value_loss = torch.compile(
    #     loss_module.value_loss, mode="reduce-overhead"
    # )
    # loss_module.actor_loss = torch.compile(
    #     loss_module.actor_loss, mode="reduce-overhead"
    # )

    def train_update(sampled_tensordict):
        # Compute loss
        q_loss, *_ = loss_module.value_loss(sampled_tensordict)

        # Update critic
        optimizer_critic.zero_grad()
        q_loss.backward()
        optimizer_critic.step()
        q_losses.append(q_loss.item())

        # Update actor
        if update_actor:
            actor_loss, *_ = loss_module.actor_loss(sampled_tensordict)
            optimizer_actor.zero_grad()
            actor_loss.backward()
            optimizer_actor.step()

            actor_losses.append(actor_loss.item())

            # Update target params
            target_net_updater.step()

    train_update_cuda = None
    g = torch.cuda.CUDAGraph()

    for _ in collector:
        sampling_time = time.time() - sampling_start
        exploration_policy[1].step(current_frames)

        # Update weights of the inference policy
        collector.update_policy_weights_()

        pbar.update(current_frames)

        # Add to replay buffer
        collected_frames += current_frames

        # Optimization steps
        training_start = time.time()

        if collected_frames >= init_random_frames:
            actor_losses, q_losses = [], []
            for _ in range(num_updates):

                # Update actor every delayed_updates
                update_counter += 1
                update_actor = update_counter % delayed_updates == 0

                # Sample from replay buffer
                sampled_tensordict = replay_buffer.sample()
                if sampled_tensordict.device != device:
                    sampled_tensordict = sampled_tensordict.to(
                        device, non_blocking=True
                    )
                else:
                    sampled_tensordict = sampled_tensordict.clone()

                if train_update_cuda is None:
                    static_sample = sampled_tensordict
                    with torch.cuda.graph(g):
                        train_update(static_sample)

                    def train_update_cuda(x):
                        static_sample.copy_(x)
                        g.replay()

                else:
                    train_update_cuda(sampled_tensordict)

                # Update priority
                if prb:
                    replay_buffer.update_priority(sampled_tensordict)

        training_time = time.time() - training_start

        # Logging
        metrics_to_log = {}
        if collected_frames >= init_random_frames:
            metrics_to_log["train/q_loss"] = np.mean(q_losses)
            if update_actor:
                metrics_to_log["train/a_loss"] = np.mean(actor_losses)
            metrics_to_log["train/sampling_time"] = sampling_time
            metrics_to_log["train/training_time"] = training_time

        if logger is not None:
            log_metrics(logger, metrics_to_log, collected_frames)
        sampling_start = time.time()

    collector.shutdown()
    if not eval_env.is_closed:
        eval_env.close()
    if not train_env.is_closed:
        train_env.close()

    with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
        reward = test_env.rollout(10_000, exploration_policy)["next", "reward"].mean()
    print(f"reward after training: {reward: 4.4f}")
    test_env.close()

    end_time = time.time()
    execution_time = end_time - start_time
    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
    main()
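
For context on the train_update_cuda path: the script records one full optimization step into a torch.cuda.CUDAGraph the first time it runs, then feeds every subsequent batch by copying it into the captured (static) tensordict and replaying the graph. Below is a minimal, self-contained sketch of that capture/replay idiom using a toy linear model and SGD; it is not part of the commit, and unlike td3-fast.py it includes the side-stream warmup that the PyTorch CUDA-graphs documentation recommends before capture. It assumes a CUDA-capable GPU.

    # Minimal sketch of CUDA-graph capture/replay with a toy model (not from the commit).
    import torch

    model = torch.nn.Linear(8, 8).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=1e-3)
    static_x = torch.randn(32, 8, device="cuda")  # placeholder input used during capture

    # Warm up in a side stream so allocations settle before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            opt.zero_grad(set_to_none=True)
            model(static_x).sum().backward()
            opt.step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture a single training step into a graph.
    g = torch.cuda.CUDAGraph()
    opt.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):
        static_loss = model(static_x).sum()
        static_loss.backward()
        opt.step()

    # Later iterations: copy fresh data into the static tensor and replay,
    # mirroring train_update_cuda's static_sample.copy_(x); g.replay() above.
    static_x.copy_(torch.randn(32, 8, device="cuda"))
    g.replay()

Everything executed during capture must read from and write to the same memory on replay, which is why td3-fast.py copies each new sample into static_sample rather than passing a fresh tensor to the captured step.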
