
Commit 1ce25f1

Author: Vincent Moens
[Feature] Log pbar rate in SOTA implementations
ghstack-source-id: 283cc1b Pull Request resolved: #2662
1 parent 1fc9577 commit 1ce25f1

29 files changed (+168, -151 lines)

sota-implementations/a2c/a2c_atari.py

Lines changed: 8 additions & 7 deletions

@@ -189,7 +189,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
  with timeit("collecting"):
      data = next(c_iter)

- log_info = {}
+ metrics_to_log = {}
  frames_in_batch = data.numel()
  collected_frames += frames_in_batch * frame_skip
  pbar.update(data.numel())

@@ -198,7 +198,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
  episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
  if len(episode_rewards) > 0:
      episode_length = data["next", "step_count"][data["next", "terminated"]]
-     log_info.update(
+     metrics_to_log.update(
          {
              "train/reward": episode_rewards.mean().item(),
              "train/episode_length": episode_length.sum().item()

@@ -242,8 +242,8 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
  losses = torch.stack(losses).float().mean()

  for key, value in losses.items():
-     log_info.update({f"train/{key}": value.item()})
- log_info.update(
+     metrics_to_log.update({f"train/{key}": value.item()})
+ metrics_to_log.update(
      {
          "train/lr": lr * alpha,
      }

@@ -259,15 +259,16 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
  test_rewards = eval_model(
      actor_eval, test_env, num_episodes=cfg.logger.num_test_episodes
  )
- log_info.update(
+ metrics_to_log.update(
      {
          "test/reward": test_rewards.mean(),
      }
  )
- log_info.update(timeit.todict(prefix="time"))

  if logger:
-     for key, value in log_info.items():
+     metrics_to_log.update(timeit.todict(prefix="time"))
+     metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+     for key, value in metrics_to_log.items():
          logger.log_scalar(key, value, collected_frames)

  collector.shutdown()
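The same refactor recurs in every script below: the metrics dict is renamed to metrics_to_log, and the timings plus the progress-bar throughput are folded in only right before logging. A minimal sketch of that pattern, assuming a tqdm progress bar and any logger exposing log_scalar; the loop body is a stand-in, not the script's actual training step:

import tqdm

pbar = tqdm.tqdm(total=1_000_000)
logger = None  # stand-in for a torchrl logger such as WandbLogger

for collected_frames in range(1_000, 11_000, 1_000):
    pbar.update(1_000)                       # frames collected this iteration
    metrics_to_log = {"train/reward": 0.0}   # placeholder training metrics

    if logger is not None:
        # fold in timings (e.g. timeit.todict(prefix="time")) and the pbar rate
        # only when they are actually going to be written out
        metrics_to_log["time/speed"] = pbar.format_dict["rate"]  # frames/s; may be None before tqdm has an estimate
        for key, value in metrics_to_log.items():
            logger.log_scalar(key, value, collected_frames)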

sota-implementations/a2c/a2c_mujoco.py

Lines changed: 8 additions & 10 deletions

@@ -186,7 +186,7 @@ def update(batch):
  with timeit("collecting"):
      data = next(c_iter)

- log_info = {}
+ metrics_to_log = {}
  frames_in_batch = data.numel()
  collected_frames += frames_in_batch
  pbar.update(data.numel())

@@ -195,7 +195,7 @@ def update(batch):
  episode_rewards = data["next", "episode_reward"][data["next", "done"]]
  if len(episode_rewards) > 0:
      episode_length = data["next", "step_count"][data["next", "done"]]
-     log_info.update(
+     metrics_to_log.update(
          {
              "train/reward": episode_rewards.mean().item(),
              "train/episode_length": episode_length.sum().item()

@@ -236,8 +236,8 @@ def update(batch):
  # Get training losses
  losses = torch.stack(losses).float().mean()
  for key, value in losses.items():
-     log_info.update({f"train/{key}": value.item()})
- log_info.update(
+     metrics_to_log.update({f"train/{key}": value.item()})
+ metrics_to_log.update(
      {
          "train/lr": alpha * cfg.optim.lr,
      }

@@ -253,21 +253,19 @@ def update(batch):
  test_rewards = eval_model(
      actor, test_env, num_episodes=cfg.logger.num_test_episodes
  )
- log_info.update(
+ metrics_to_log.update(
      {
          "test/reward": test_rewards.mean(),
      }
  )
  actor.train()

- log_info.update(timeit.todict(prefix="time"))
-
  if logger:
-     for key, value in log_info.items():
+     metrics_to_log.update(timeit.todict(prefix="time"))
+     metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+     for key, value in metrics_to_log.items():
          logger.log_scalar(key, value, collected_frames)

- torch.compiler.cudagraph_mark_step_begin()
-
  collector.shutdown()
  if not test_env.is_closed:
      test_env.close()

sota-implementations/cql/cql_offline.py

Lines changed: 5 additions & 4 deletions

@@ -172,7 +172,7 @@ def update(data, policy_eval_start, iteration):
  )

  # log metrics
- to_log = {
+ metrics_to_log = {
      "loss": loss.cpu(),
      **loss_vals.cpu(),
  }

@@ -188,11 +188,12 @@ def update(data, policy_eval_start, iteration):
  )
  eval_env.apply(dump_video)
  eval_reward = eval_td["next", "reward"].sum(1).mean().item()
- to_log["evaluation_reward"] = eval_reward
+ metrics_to_log["evaluation_reward"] = eval_reward

  with timeit("log"):
-     to_log.update(timeit.todict(prefix="time"))
-     log_metrics(logger, to_log, i)
+     metrics_to_log.update(timeit.todict(prefix="time"))
+     metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+     log_metrics(logger, metrics_to_log, i)

  pbar.close()
  if not eval_env.is_closed:

sota-implementations/cql/cql_online.py

Lines changed: 2 additions & 1 deletion

@@ -220,7 +220,6 @@ def update(sampled_tensordict):
      "loss_alpha_prime"
  ).mean()
  metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean()
- metrics_to_log.update(timeit.todict(prefix="time"))

  # Evaluation
  with timeit("eval"):

@@ -241,6 +240,8 @@ def update(sampled_tensordict):
  eval_env.apply(dump_video)
  metrics_to_log["eval/reward"] = eval_reward

+ metrics_to_log.update(timeit.todict(prefix="time"))
+ metrics_to_log["time/speed"] = pbar.format_dict["rate"]
  log_metrics(logger, metrics_to_log, collected_frames)

  collector.shutdown()

sota-implementations/cql/discrete_cql_online.py

Lines changed: 3 additions & 2 deletions

@@ -179,7 +179,7 @@ def update(sampled_tensordict):
  sampled_tensordict = sampled_tensordict.to(device)
  with timeit("update"):
      torch.compiler.cudagraph_mark_step_begin()
-     loss_dict = update(sampled_tensordict)
+     loss_dict = update(sampled_tensordict).clone()
  tds.append(loss_dict)

  # Update priority

@@ -222,9 +222,10 @@ def update(sampled_tensordict):
  tds = torch.stack(tds, dim=0).mean()
  metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
  metrics_to_log["train/cql_loss"] = tds["loss_cql"]
- metrics_to_log.update(timeit.todict(prefix="time"))

  if logger is not None:
+     metrics_to_log.update(timeit.todict(prefix="time"))
+     metrics_to_log["time/speed"] = pbar.format_dict["rate"]
      log_metrics(logger, metrics_to_log, collected_frames)

  collector.shutdown()
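The new .clone() on the update output is worth noting: when update is cudagraph-captured (or compiled in a reduce-overhead mode), its output tensors live in buffers that the next replay overwrites, so each iteration's losses must be cloned before being stashed and stacked. A rough sketch of the idea with a stand-in update function; the real one is the script's compiled loss/optimizer step:

import torch
from tensordict import TensorDict

def update(td):
    # stand-in for the captured update step returning per-step losses
    return TensorDict({"loss_qvalue": torch.rand(()), "loss_cql": torch.rand(())}, [])

tds = []
for _ in range(4):
    torch.compiler.cudagraph_mark_step_begin()      # mark the step before the captured call
    tds.append(update(TensorDict({}, [])).clone())  # clone so a later replay cannot overwrite it

losses = torch.stack(tds, dim=0).mean()             # per-key mean over the collected steps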

sota-implementations/cql/utils.py

Lines changed: 1 addition & 1 deletion

@@ -185,7 +185,7 @@ def make_offline_replay_buffer(rb_cfg):
      dataset_id=rb_cfg.dataset,
      split_trajs=False,
      batch_size=rb_cfg.batch_size,
-     sampler=SamplerWithoutReplacement(drop_last=False),
+     sampler=SamplerWithoutReplacement(drop_last=True),
      prefetch=4,
      direct_download=True,
  )
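The drop_last=True switch means the offline buffer never yields a final, smaller-than-batch_size batch at the end of an epoch, which keeps shapes static for a compiled or captured update path. A toy illustration of the sampler's behaviour (real torchrl classes, made-up sizes; the actual function builds a D4RL experience replay):

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import SamplerWithoutReplacement

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(10),
    sampler=SamplerWithoutReplacement(drop_last=True),
    batch_size=4,
)
rb.extend(TensorDict({"obs": torch.arange(10).unsqueeze(-1)}, [10]))

for batch in rb:         # two full batches of 4; the trailing 2 samples are dropped
    print(batch.shape)   # torch.Size([4]) both times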

sota-implementations/crossq/crossq.py

Lines changed: 2 additions & 1 deletion

@@ -256,13 +256,14 @@ def update(sampled_tensordict: TensorDict, update_actor: bool):
  metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
      episode_length
  )
- metrics_to_log.update(timeit.todict(prefix="time"))
  if collected_frames >= init_random_frames:
      metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
      metrics_to_log["train/actor_loss"] = tds["loss_actor"]
      metrics_to_log["train/alpha_loss"] = tds["loss_alpha"]

  if logger is not None:
+     metrics_to_log.update(timeit.todict(prefix="time"))
+     metrics_to_log["time/speed"] = pbar.format_dict["rate"]
      log_metrics(logger, metrics_to_log, collected_frames)

  collector.shutdown()

sota-implementations/ddpg/ddpg.py

Lines changed: 2 additions & 1 deletion

@@ -224,9 +224,10 @@ def update(sampled_tensordict):
  eval_env.apply(dump_video)
  eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
  metrics_to_log["eval/reward"] = eval_reward
- metrics_to_log.update(timeit.todict(prefix="time"))

  if logger is not None:
+     metrics_to_log.update(timeit.todict(prefix="time"))
+     metrics_to_log["time/speed"] = pbar.format_dict["rate"]
      log_metrics(logger, metrics_to_log, collected_frames)

  collector.shutdown()

sota-implementations/decision_transformer/dt.py

Lines changed: 8 additions & 5 deletions

@@ -76,7 +76,9 @@ def main(cfg: "DictConfig"): # noqa: F821
  loss_module = make_dt_loss(cfg.loss, actor, device=model_device)

  # Create optimizer
- transformer_optim, scheduler = make_dt_optimizer(cfg.optim, loss_module)
+ transformer_optim, scheduler = make_dt_optimizer(
+     cfg.optim, loss_module, model_device
+ )

  # Create inference policy
  inference_policy = DecisionTransformerInferenceWrapper(

@@ -136,7 +138,7 @@ def update(data: TensorDict) -> TensorDict:
  loss_vals = update(data)
  scheduler.step()
  # Log metrics
- to_log = {"train/loss": loss_vals["loss"]}
+ metrics_to_log = {"train/loss": loss_vals["loss"]}

  # Evaluation
  with set_exploration_type(

@@ -149,13 +151,14 @@ def update(data: TensorDict) -> TensorDict:
      auto_cast_to_device=True,
  )
  test_env.apply(dump_video)
- to_log["eval/reward"] = (
+ metrics_to_log["eval/reward"] = (
      eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
  )
- to_log.update(timeit.todict(prefix="time"))

  if logger is not None:
-     log_metrics(logger, to_log, i)
+     metrics_to_log.update(timeit.todict(prefix="time"))
+     metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+     log_metrics(logger, metrics_to_log, i)

  pbar.close()
  if not test_env.is_closed:

sota-implementations/decision_transformer/online_dt.py

Lines changed: 5 additions & 5 deletions

@@ -143,7 +143,7 @@ def update(data):
  scheduler.step()

  # Log metrics
- to_log = {
+ metrics_to_log = {
      "train/loss_log_likelihood": loss_vals["loss_log_likelihood"],
      "train/loss_entropy": loss_vals["loss_entropy"],
      "train/loss_alpha": loss_vals["loss_alpha"],

@@ -165,14 +165,14 @@ def update(data):
  )
  test_env.apply(dump_video)
  inference_policy.train()
- to_log["eval/reward"] = (
+ metrics_to_log["eval/reward"] = (
      eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
  )

- to_log.update(timeit.todict(prefix="time"))
-
  if logger is not None:
-     log_metrics(logger, to_log, i)
+     metrics_to_log.update(timeit.todict(prefix="time"))
+     metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+     log_metrics(logger, metrics_to_log, i)

  pbar.close()
  if not test_env.is_closed:

sota-implementations/decision_transformer/utils.py

Lines changed: 2 additions & 2 deletions

@@ -511,10 +511,10 @@ def make_odt_optimizer(optim_cfg, loss_module):
  return dt_optimizer, log_temp_optimizer, scheduler


- def make_dt_optimizer(optim_cfg, loss_module):
+ def make_dt_optimizer(optim_cfg, loss_module, device):
      dt_optimizer = torch.optim.Adam(
          loss_module.actor_network_params.flatten_keys().values(),
-         lr=torch.as_tensor(optim_cfg.lr),
+         lr=torch.tensor(optim_cfg.lr, device=device),
          weight_decay=optim_cfg.weight_decay,
          eps=1.0e-8,
      )
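make_dt_optimizer now receives the model device so the learning rate is created as a tensor colocated with the parameters. torch.optim accepts a tensor lr, and keeping it on the parameters' device avoids host/device syncs when the optimizer step is compiled or cudagraph-captured (for full graph capture one would typically also pass capturable=True). A hedged sketch with a toy module standing in for the DT loss module's actor parameters:

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(4, 2).to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=torch.tensor(1e-4, device=device),  # tensor lr living next to the params
    weight_decay=5e-4,
    eps=1.0e-8,
)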

sota-implementations/discrete_sac/discrete_sac.py

Lines changed: 2 additions & 1 deletion

@@ -227,8 +227,9 @@ def update(sampled_tensordict):
  eval_env.apply(dump_video)
  eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
  metrics_to_log["eval/reward"] = eval_reward
- metrics_to_log.update(timeit.todict(prefix="time"))
  if logger is not None:
+     metrics_to_log.update(timeit.todict(prefix="time"))
+     metrics_to_log["time/speed"] = pbar.format_dict["rate"]
      log_metrics(logger, metrics_to_log, collected_frames)

  collector.shutdown()

sota-implementations/dqn/config_atari.yaml

Lines changed: 3 additions & 3 deletions

@@ -7,7 +7,7 @@ env:
  # collector
  collector:
    total_frames: 40_000_100
-   frames_per_batch: 16
+   frames_per_batch: 1600
    eps_start: 1.0
    eps_end: 0.01
    annealing_frames: 4_000_000

@@ -38,9 +38,9 @@ optim:
  loss:
    gamma: 0.99
    hard_update_freq: 10_000
-   num_updates: 1
+   num_updates: 100

  compile:
    compile: False
-   compile_mode:
+   compile_mode: default
    cudagraphs: False
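The Atari DQN config keeps the overall update-to-data ratio unchanged while batching the work: the collector now returns 1600 frames at a time and the trainer performs 100 updates per batch, i.e. still one update per 16 collected frames, but with far fewer collector round-trips and logging calls per frame (compile_mode also gets an explicit "default"). The cartpole config below gets the same treatment (1000 frames, 100 updates: one update per 10 frames, as before). A quick sanity check of the ratio:

# update-to-frame ratio before and after the Atari config change
old_ratio = 1 / 16        # num_updates=1, frames_per_batch=16
new_ratio = 100 / 1600    # num_updates=100, frames_per_batch=1600
assert old_ratio == new_ratio  # 0.0625 updates per collected frame either way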

sota-implementations/dqn/config_cartpole.yaml

Lines changed: 2 additions & 2 deletions

@@ -7,7 +7,7 @@ env:
  # collector
  collector:
    total_frames: 500_100
-   frames_per_batch: 10
+   frames_per_batch: 1000
    eps_start: 1.0
    eps_end: 0.05
    annealing_frames: 250_000

@@ -37,7 +37,7 @@ optim:
  loss:
    gamma: 0.99
    hard_update_freq: 50
-   num_updates: 1
+   num_updates: 100

  compile:
    compile: False
