Commit 256a700

[Feature] Make PPO compatible with composite actions and log-probs

ghstack-source-id: c41718e
Pull Request resolved: #2665

1 parent dc25a55 · commit 256a700

25 files changed: +829, -253 lines

.github/unittest/linux_sota/scripts/test_sota.py

Lines changed: 13 additions & 13 deletions

@@ -188,19 +188,6 @@
     ppo.collector.frames_per_batch=16 \
     logger.mode=offline \
     logger.backend=
-    """,
-    "dreamer": """python sota-implementations/dreamer/dreamer.py \
-    collector.total_frames=600 \
-    collector.init_random_frames=10 \
-    collector.frames_per_batch=200 \
-    env.n_parallel_envs=1 \
-    optimization.optim_steps_per_batch=1 \
-    logger.video=False \
-    logger.backend=csv \
-    replay_buffer.buffer_size=120 \
-    replay_buffer.batch_size=24 \
-    replay_buffer.batch_length=12 \
-    networks.rssm_hidden_dim=17
     """,
     "ddpg-single": """python sota-implementations/ddpg/ddpg.py \
     collector.total_frames=48 \
@@ -289,6 +276,19 @@
     logger.backend=
     """,
     "bandits": """python sota-implementations/bandits/dqn.py --n_steps=100
+    """,
+    "dreamer": """python sota-implementations/dreamer/dreamer.py \
+    collector.total_frames=600 \
+    collector.init_random_frames=10 \
+    collector.frames_per_batch=200 \
+    env.n_parallel_envs=1 \
+    optimization.optim_steps_per_batch=1 \
+    logger.video=False \
+    logger.backend=csv \
+    replay_buffer.buffer_size=120 \
+    replay_buffer.batch_size=24 \
+    replay_buffer.batch_length=12 \
+    networks.rssm_hidden_dim=17
     """,
 }

examples/agents/composite_actor.py

Lines changed: 6 additions & 0 deletions

@@ -50,3 +50,9 @@ def forward(self, x):
 data = TensorDict({"x": torch.rand(10)}, [])
 module(data)
 print(actor(data))
+
+
+# TODO:
+# 1. Use ("action", "action0") + ("action", "action1") vs ("agent0", "action") + ("agent1", "action")
+# 2. Must multi-head require an action_key to be a list of keys (I guess so)
+# 3. Using maps in the Actor
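
Side note, not part of the commit: the two layouts mentioned in TODO item 1 correspond to two different `name_map` choices for a `CompositeDistribution`. A minimal sketch, with hypothetical head names ("head0"/"head1"):

# Illustration only: the two key layouts from TODO item 1, expressed as
# name_map dictionaries for CompositeDistribution.
layout_nested_actions = {
    # every head grouped under a single "action" entry
    "head0": ("action", "action0"),
    "head1": ("action", "action1"),
}
layout_per_agent = {
    # one sub-tensordict per agent, each holding its own "action"
    "head0": ("agent0", "action"),
    "head1": ("agent1", "action"),
}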

examples/agents/composite_ppo.py

Lines changed: 190 additions & 0 deletions

@@ -0,0 +1,190 @@ (new file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Multi-head Agent and PPO Loss
=============================
This example demonstrates how to use TorchRL to create a multi-head agent with three separate distributions
(Gamma, Kumaraswamy, and Mixture) and train it using Proximal Policy Optimization (PPO) losses.

Step-by-step Explanation
------------------------

1. **Setting Composite Log-Probabilities**:
   - To use composite (i.e., multi-head) distributions with PPO (or any other algorithm that relies on probability
     distributions, like SAC or A2C), you must call `set_composite_lp_aggregate(False).set()`. Not calling this will
     result in errors during the execution of your script.
   - From torchrl and tensordict v0.9, this will be the default behavior. Not doing this will result in
     `CompositeDistribution` aggregating the log-probs, which may lead to incorrect log-probabilities.
   - Note that `set_composite_lp_aggregate(False).set()` will cause the sample log-probabilities to be named
     `<action_key>_log_prob` for any probability distribution, not just composite ones. For regular, single-head
     policies, for instance, the log-probability will be named `"action_log_prob"`.
     Previously, log-prob keys defaulted to `sample_log_prob`.
2. **Action Grouping**:
   - Actions can be grouped or not; PPO doesn't require them to be grouped.
   - If actions are grouped, calling the policy will result in a `TensorDict` with fields for each agent's action and
     log-probability, e.g., `agent0`, `agent0_log_prob`, etc.

    ... [...]
    ...     action: TensorDict(
    ...         fields={
    ...             agent0: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
    ...             agent0_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
    ...             agent1: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
    ...             agent1_log_prob: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
    ...             agent2: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
    ...             agent2_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
    ...         batch_size=torch.Size([4]),
    ...         device=None,
    ...         is_shared=False),

   - If actions are not grouped, each agent will have its own `TensorDict` with `action` and `action_log_prob` fields.

    ... [...]
    ...     agent0: TensorDict(
    ...         fields={
    ...             action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
    ...             action_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
    ...         batch_size=torch.Size([4]),
    ...         device=None,
    ...         is_shared=False),
    ...     agent1: TensorDict(
    ...         fields={
    ...             action: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
    ...             action_log_prob: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    ...         batch_size=torch.Size([4]),
    ...         device=None,
    ...         is_shared=False),
    ...     agent2: TensorDict(
    ...         fields={
    ...             action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
    ...             action_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
    ...         batch_size=torch.Size([4]),
    ...         device=None,
    ...         is_shared=False),

3. **PPO Loss Calculation**:
   - Under the hood, `ClipPPOLoss` will clip the individual weights (not the aggregate) and multiply them by the
     advantage.

The code below sets up a multi-head agent with three distributions and demonstrates how to train it using PPO losses.

"""

import functools

import torch
from tensordict import TensorDict
from tensordict.nn import (
    CompositeDistribution,
    InteractionType,
    ProbabilisticTensorDictModule as Prob,
    ProbabilisticTensorDictSequential as ProbSeq,
    set_composite_lp_aggregate,
    TensorDictModule as Mod,
    TensorDictSequential as Seq,
    WrapModule as Wrap,
)
from torch import distributions as d
from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss

set_composite_lp_aggregate(False).set()

GROUPED_ACTIONS = False

make_params = Mod(
    lambda: (
        torch.ones(4),
        torch.ones(4),
        torch.ones(4, 2),
        torch.ones(4, 2),
        torch.ones(4, 10) / 10,
        torch.zeros(4, 10),
        torch.ones(4, 10),
    ),
    in_keys=[],
    out_keys=[
        ("params", "gamma", "concentration"),
        ("params", "gamma", "rate"),
        ("params", "Kumaraswamy", "concentration0"),
        ("params", "Kumaraswamy", "concentration1"),
        ("params", "mixture", "logits"),
        ("params", "mixture", "loc"),
        ("params", "mixture", "scale"),
    ],
)


def mixture_constructor(logits, loc, scale):
    return d.MixtureSameFamily(
        d.Categorical(logits=logits), d.Normal(loc=loc, scale=scale)
    )


if GROUPED_ACTIONS:
    name_map = {
        "gamma": ("action", "agent0"),
        "Kumaraswamy": ("action", "agent1"),
        "mixture": ("action", "agent2"),
    }
else:
    name_map = {
        "gamma": ("agent0", "action"),
        "Kumaraswamy": ("agent1", "action"),
        "mixture": ("agent2", "action"),
    }

dist_constructor = functools.partial(
    CompositeDistribution,
    distribution_map={
        "gamma": d.Gamma,
        "Kumaraswamy": d.Kumaraswamy,
        "mixture": mixture_constructor,
    },
    name_map=name_map,
)


policy = ProbSeq(
    make_params,
    Prob(
        in_keys=["params"],
        out_keys=list(name_map.values()),
        distribution_class=dist_constructor,
        return_log_prob=True,
        default_interaction_type=InteractionType.RANDOM,
    ),
)

td = policy(TensorDict(batch_size=[4]))
print("Result of policy call", td)

dist = policy.get_dist(td)
log_prob = dist.log_prob(td)
print("Composite log-prob", log_prob)

# Build a dummy value operator
value_operator = Seq(
    Wrap(
        lambda td: td.set("state_value", torch.ones((*td.shape, 1))),
        out_keys=["state_value"],
    )
)

# Create fake data
data = policy(TensorDict(batch_size=[4]))
data.set(
    "next",
    TensorDict(reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)),
)

# Instantiate the loss - test the 3 different PPO losses
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
    # PPO sets the keys automatically by looking at the policy
    ppo = loss_cls(policy, value_operator)
    print("tensor keys", ppo.tensor_keys)

    # Get the loss values
    loss_vals = ppo(data)
    print("Loss result:", loss_cls, loss_vals)
