Commit cbb8a54

[Algorithm] TD3 fast

ghstack-source-id: b31466e
Pull Request resolved: #2389

1 parent ecc5e00 commit cbb8a54

4 files changed: +365 -33 lines changed

sota-implementations/td3/config-fast.yaml

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
# task and env
env:
  name: HalfCheetah-v4  # Use v4 to get rid of mujoco-py dependency
  task: ""
  library: gymnasium
  seed: 42
  max_episode_steps: 1000

# collector
collector:
  total_frames: 1000000
  init_random_frames: 25_000
  init_env_steps: 1000
  frames_per_batch: 1000
  reset_at_each_iter: False
  device: cpu
  env_per_collector: 1
  num_workers: 8

# replay buffer
replay_buffer:
  prb: 0  # use prioritized experience replay
  size: 1000000
  scratch_dir: null
  device: null

# optim
optim:
  utd_ratio: 1.0
  gamma: 0.99
  loss_function: l2
  lr: 3.0e-4
  weight_decay: 0.0
  adam_eps: 1e-4
  batch_size: 256
  target_update_polyak: 0.995
  policy_update_delay: 2
  policy_noise: 0.2
  noise_clip: 0.5

# network
network:
  hidden_sizes: [256, 256]
  activation: relu
  device: null

# logging
logger:
  backend: wandb
  project_name: torchrl_example_td3
  group_name: null
  exp_name: ${env.name}_TD3
  mode: online
  eval_iter: 25000
  video: False
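
The policy_noise, noise_clip, policy_update_delay, and target_update_polyak entries map onto TD3's core mechanisms: target-policy smoothing, delayed actor updates, and soft target-network updates. The short sketch below only illustrates what the smoothing and Polyak values mean numerically; it is not TorchRL's TD3Loss implementation, and the batch and action dimensions are made up.

import torch

# Illustrative sketch of TD3 target-policy smoothing as set by policy_noise / noise_clip.
policy_noise, noise_clip = 0.2, 0.5
next_action = torch.randn(256, 6)  # hypothetical batch of target-policy actions
noise = (torch.randn_like(next_action) * policy_noise).clamp(-noise_clip, noise_clip)
smoothed_action = (next_action + noise).clamp(-1.0, 1.0)  # assumes actions bounded in [-1, 1]

# target_update_polyak = 0.995 corresponds to soft target updates of the form
#   theta_target <- 0.995 * theta_target + (1 - 0.995) * theta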

sota-implementations/td3/td3-fast.py

Lines changed: 230 additions & 0 deletions
@@ -0,0 +1,230 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""TD3 Example.

This is a simple self-contained example of a TD3 training script.

It supports state environments like MuJoCo.

The helper functions are coded in the utils.py associated with this script.
"""
import time

import hydra
import numpy as np
import torch
import torch.cuda
import tqdm
from torchrl._utils import logger as torchrl_logger
from torchrl.data.utils import CloudpickleWrapper

from torchrl.envs.utils import ExplorationType, set_exploration_type

from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
    log_metrics,
    make_async_collector,
    make_environment,
    make_loss_module,
    make_optimizer,
    make_replay_buffer,
    make_simple_environment,
    make_td3_agent,
)


@hydra.main(version_base="1.1", config_path="", config_name="config-fast")
def main(cfg: "DictConfig"):  # noqa: F821
    device = cfg.network.device
    if device in ("", None):
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            device = torch.device("cpu")
    device = torch.device(device)

    # Create logger
    exp_name = generate_exp_name("TD3", cfg.logger.exp_name)
    logger = None
    if cfg.logger.backend:
        logger = get_logger(
            logger_type=cfg.logger.backend,
            logger_name="td3_logging",
            experiment_name=exp_name,
            wandb_kwargs={
                "mode": cfg.logger.mode,
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    # Set seeds
    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)

    # Create environments
    train_env, eval_env = make_environment(cfg, logger=logger)

    # Create agent
    model, exploration_policy = make_td3_agent(cfg, train_env, eval_env, device)

    # Create TD3 loss
    loss_module, target_net_updater = make_loss_module(cfg, model)

    # Create replay buffer
    replay_buffer = make_replay_buffer(
        batch_size=cfg.optim.batch_size,
        prb=cfg.replay_buffer.prb,
        buffer_size=cfg.replay_buffer.size,
        scratch_dir=cfg.replay_buffer.scratch_dir,
        device=cfg.replay_buffer.device if cfg.replay_buffer.device else device,
        prefetch=0,
        mmap=False,
    )
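    # Flatten incoming batches into a 1-D stream of transitions before storage:
    # with invert=True the reshape below is applied when data is written to the
    # buffer rather than when it is sampled.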
    reshape = CloudpickleWrapper(lambda td: td.reshape(-1))
    replay_buffer.append_transform(reshape, invert=True)

    # Create off-policy collector
    envname = cfg.env.name
    task = cfg.env.task
    library = cfg.env.library
    seed = cfg.env.seed
    max_episode_steps = cfg.env.max_episode_steps
    collector = make_async_collector(
        cfg,
        lambda: make_simple_environment(
            envname, task, library, seed, max_episode_steps
        ),
        exploration_policy,
        replay_buffer,
    )

    # Create optimizers
    optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)

    # Main loop
    start_time = time.time()
    collected_frames = 0
    pbar = tqdm.tqdm(total=cfg.collector.total_frames)

    init_random_frames = cfg.collector.init_random_frames
    num_updates = int(
        max(1, cfg.collector.env_per_collector)
        * cfg.collector.frames_per_batch
        * cfg.optim.utd_ratio
    )
    delayed_updates = cfg.optim.policy_update_delay
    prb = cfg.replay_buffer.prb
    update_counter = 0

    sampling_start = time.time()
    current_frames = cfg.collector.frames_per_batch
    update_actor = False

    test_env = make_simple_environment(envname, task, library, seed, max_episode_steps)
    with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
        reward = test_env.rollout(10_000, exploration_policy)["next", "reward"].mean()
    print(f"reward before training: {reward: 4.4f}")

    # Compile the loss functions once, before entering the training loop
    loss_module.value_loss = torch.compile(
        loss_module.value_loss, mode="reduce-overhead"
    )
    loss_module.actor_loss = torch.compile(
        loss_module.actor_loss, mode="reduce-overhead"
    )

    for _ in collector:
        sampling_time = time.time() - sampling_start
        exploration_policy[1].step(current_frames)

        # Update weights of the inference policy
        collector.update_policy_weights_()

        pbar.update(current_frames)

        # Add to replay buffer
        collected_frames += current_frames

        # Optimization steps
        training_start = time.time()
        if collected_frames >= init_random_frames:
            (
                actor_losses,
                q_losses,
            ) = ([], [])
            for _ in range(num_updates):

                # Update actor every delayed_updates
                update_counter += 1
                update_actor = update_counter % delayed_updates == 0

                # Sample from replay buffer
                sampled_tensordict = replay_buffer.sample()
                if sampled_tensordict.device != device:
                    sampled_tensordict = sampled_tensordict.to(
                        device, non_blocking=True
                    )
                else:
                    sampled_tensordict = sampled_tensordict.clone()

                # Compute loss
                q_loss, *_ = loss_module.value_loss(sampled_tensordict)

                # Update critic
                optimizer_critic.zero_grad()
                q_loss.backward()
                optimizer_critic.step()
                q_losses.append(q_loss.item())

                # Update actor
                if update_actor:
                    actor_loss, *_ = loss_module.actor_loss(sampled_tensordict)
                    optimizer_actor.zero_grad()
                    actor_loss.backward()
                    optimizer_actor.step()

                    actor_losses.append(actor_loss.item())

                    # Update target params
                    target_net_updater.step()

                # Update priority
                if prb:
                    replay_buffer.update_priority(sampled_tensordict)

        training_time = time.time() - training_start

        # Logging
        metrics_to_log = {}
        if collected_frames >= init_random_frames:
            metrics_to_log["train/q_loss"] = np.mean(q_losses)
            if update_actor:
                metrics_to_log["train/a_loss"] = np.mean(actor_losses)
            metrics_to_log["train/sampling_time"] = sampling_time
            metrics_to_log["train/training_time"] = training_time

        if logger is not None:
            log_metrics(logger, metrics_to_log, collected_frames)
        sampling_start = time.time()

    collector.shutdown()
    if not eval_env.is_closed:
        eval_env.close()
    if not train_env.is_closed:
        train_env.close()

    with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
        reward = test_env.rollout(10_000, exploration_policy)["next", "reward"].mean()
    print(f"reward after training: {reward: 4.4f}")
    test_env.close()

    end_time = time.time()
    execution_time = end_time - start_time
    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
    main()
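
The environment factory make_simple_environment is imported from utils.py, which is not shown in this diff. Below is a rough, hypothetical sketch of what such a single-environment factory can look like with TorchRL primitives; it is not the actual helper, and it ignores the task/library dispatch the real one presumably performs.

from torchrl.envs import (
    Compose,
    DoubleToFloat,
    GymEnv,
    RewardSum,
    StepCounter,
    TransformedEnv,
)


def make_simple_environment_sketch(env_name, task, library, seed, max_episode_steps):
    # task and library are accepted for signature parity but ignored in this sketch
    base_env = GymEnv(env_name)
    env = TransformedEnv(
        base_env,
        Compose(
            StepCounter(max_steps=max_episode_steps),  # truncate long episodes
            DoubleToFloat(),  # MuJoCo observations are float64; networks expect float32
            RewardSum(),  # accumulate episode returns for logging
        ),
    )
    env.set_seed(seed)
    return env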

0 commit comments
