Commit a2145e1

[Feature] Make PPO compatible with composite actions and log-probs

ghstack-source-id: 3bcf7eb
Pull Request resolved: #2665
1 parent d009835 · commit a2145e1

File tree: 5 files changed, +455 -70 lines


examples/agents/composite_ppo.py

Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Multi-head agent and PPO loss
=============================

This example demonstrates how to use TorchRL to create a multi-head agent with three separate distributions
(Gamma, Kumaraswamy, and Mixture) and train it using Proximal Policy Optimization (PPO) losses.

The code first defines a module `make_params` that extracts the parameters of the distributions from an input tensordict.
It then creates a `dist_constructor` function that takes these parameters as input and outputs a CompositeDistribution
object containing the three distributions.

The policy is defined as a ProbabilisticTensorDictSequential module that reads an observation, casts it to parameters,
creates a distribution from these parameters, and samples from the distribution to output multiple actions.

The example tests the policy with fake data across three different PPO losses: PPOLoss, ClipPPOLoss, and KLPENPPOLoss.

Note that the `log_prob` method of the CompositeDistribution object can return either an aggregated tensor or a
fine-grained tensordict with individual log-probabilities, depending on the value of the `aggregate_probabilities`
argument. The PPO loss modules are designed to handle both cases, and will default to `aggregate_probabilities=False`
if not specified.

In particular, if `aggregate_probabilities=False` and `include_sum=True`, the summed log-probs will also be included in
the output tensordict. However, since we have access to the individual log-probs, this feature is not typically used.

"""

import functools

import torch
from tensordict import TensorDict
from tensordict.nn import (
    CompositeDistribution,
    InteractionType,
    ProbabilisticTensorDictModule as Prob,
    ProbabilisticTensorDictSequential as ProbSeq,
    TensorDictModule as Mod,
    TensorDictSequential as Seq,
    WrapModule as Wrap,
)
from torch import distributions as d
from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss

make_params = Mod(
    lambda: (
        torch.ones(4),
        torch.ones(4),
        torch.ones(4, 2),
        torch.ones(4, 2),
        torch.ones(4, 10) / 10,
        torch.zeros(4, 10),
        torch.ones(4, 10),
    ),
    in_keys=[],
    out_keys=[
        ("params", "gamma", "concentration"),
        ("params", "gamma", "rate"),
        ("params", "Kumaraswamy", "concentration0"),
        ("params", "Kumaraswamy", "concentration1"),
        ("params", "mixture", "logits"),
        ("params", "mixture", "loc"),
        ("params", "mixture", "scale"),
    ],
)


def mixture_constructor(logits, loc, scale):
    return d.MixtureSameFamily(
        d.Categorical(logits=logits), d.Normal(loc=loc, scale=scale)
    )


# =============================================================================
# Example 0: aggregate_probabilities=None (default) ===========================

dist_constructor = functools.partial(
    CompositeDistribution,
    distribution_map={
        "gamma": d.Gamma,
        "Kumaraswamy": d.Kumaraswamy,
        "mixture": mixture_constructor,
    },
    name_map={
        "gamma": ("agent0", "action"),
        "Kumaraswamy": ("agent1", "action"),
        "mixture": ("agent2", "action"),
    },
    aggregate_probabilities=None,
)


policy = ProbSeq(
    make_params,
    Prob(
        in_keys=["params"],
        out_keys=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
        distribution_class=dist_constructor,
        return_log_prob=True,
        default_interaction_type=InteractionType.RANDOM,
    ),
)

td = policy(TensorDict(batch_size=[4]))
print("0. result of policy call", td)

dist = policy.get_dist(td)
log_prob = dist.log_prob(
    td, aggregate_probabilities=False, inplace=False, include_sum=False
)
print("0. non-aggregated log-prob")

# We can also get the log-prob from the policy directly
log_prob = policy.log_prob(
    td, aggregate_probabilities=False, inplace=False, include_sum=False
)
print("0. non-aggregated log-prob (from policy)")

# Build a dummy value operator
value_operator = Seq(
    Wrap(
        lambda td: td.set("state_value", torch.ones((*td.shape, 1))),
        out_keys=["state_value"],
    )
)

# Create fake data
data = policy(TensorDict(batch_size=[4]))
data.set(
    "next",
    TensorDict(reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)),
)

# Instantiate the loss
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
    ppo = loss_cls(policy, value_operator)

    # Keys are not the default ones - there is more than one action
    ppo.set_keys(
        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
        sample_log_prob=[
            ("agent0", "action_log_prob"),
            ("agent1", "action_log_prob"),
            ("agent2", "action_log_prob"),
        ],
    )

    # Get the loss values
    loss_vals = ppo(data)
    print("0. ", loss_cls, loss_vals)


# ===================================================================
# Example 1: aggregate_probabilities=True ===========================

dist_constructor.keywords["aggregate_probabilities"] = True

td = policy(TensorDict(batch_size=[4]))
print("1. result of policy call", td)

# Instantiate the loss
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
    ppo = loss_cls(policy, value_operator)

    # Keys are not the default ones - there is more than one action. No need to indicate the sample-log-prob key, since
    # there is only one.
    ppo.set_keys(
        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")]
    )

    # Get the loss values
    loss_vals = ppo(data)
    print("1. ", loss_cls, loss_vals)


# ===================================================================
# Example 2: aggregate_probabilities=False ===========================

dist_constructor.keywords["aggregate_probabilities"] = False

td = policy(TensorDict(batch_size=[4]))
print("2. result of policy call", td)

# Instantiate the loss
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
    ppo = loss_cls(policy, value_operator)

    # Keys are not the default ones - there is more than one action
    ppo.set_keys(
        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
        sample_log_prob=[
            ("agent0", "action_log_prob"),
            ("agent1", "action_log_prob"),
            ("agent2", "action_log_prob"),
        ],
    )

    # Get the loss values
    loss_vals = ppo(data)
    print("2. ", loss_cls, loss_vals)

test/test_cost.py

Lines changed: 106 additions & 3 deletions
@@ -34,6 +34,7 @@
     TensorDictModule as Mod,
     TensorDictSequential,
     TensorDictSequential as Seq,
+    WrapModule,
 )
 from tensordict.nn.utils import Buffer
 from tensordict.utils import unravel_key
@@ -8864,9 +8865,7 @@ def test_ppo_tensordict_keys_run(
     @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
     @pytest.mark.parametrize(
         "composite_action_dist",
-        [
-            False,
-        ],
+        [False],
     )
     def test_ppo_notensordict(
         self,
@@ -9060,6 +9059,110 @@ def test_ppo_value_clipping(
         loss = loss_fn(td)
         assert "loss_critic" in loss.keys()
 
+    def test_ppo_composite_dists(self):
+        d = torch.distributions
+
+        make_params = TensorDictModule(
+            lambda: (
+                torch.ones(4),
+                torch.ones(4),
+                torch.ones(4, 2),
+                torch.ones(4, 2),
+                torch.ones(4, 10) / 10,
+                torch.zeros(4, 10),
+                torch.ones(4, 10),
+            ),
+            in_keys=[],
+            out_keys=[
+                ("params", "gamma", "concentration"),
+                ("params", "gamma", "rate"),
+                ("params", "Kumaraswamy", "concentration0"),
+                ("params", "Kumaraswamy", "concentration1"),
+                ("params", "mixture", "logits"),
+                ("params", "mixture", "loc"),
+                ("params", "mixture", "scale"),
+            ],
+        )
+
+        def mixture_constructor(logits, loc, scale):
+            return d.MixtureSameFamily(
+                d.Categorical(logits=logits), d.Normal(loc=loc, scale=scale)
+            )
+
+        dist_constructor = functools.partial(
+            CompositeDistribution,
+            distribution_map={
+                "gamma": d.Gamma,
+                "Kumaraswamy": d.Kumaraswamy,
+                "mixture": mixture_constructor,
+            },
+            name_map={
+                "gamma": ("agent0", "action"),
+                "Kumaraswamy": ("agent1", "action"),
+                "mixture": ("agent2", "action"),
+            },
+            aggregate_probabilities=False,
+            include_sum=False,
+            inplace=True,
+        )
+        policy = ProbSeq(
+            make_params,
+            ProbabilisticTensorDictModule(
+                in_keys=["params"],
+                out_keys=[
+                    ("agent0", "action"),
+                    ("agent1", "action"),
+                    ("agent2", "action"),
+                ],
+                distribution_class=dist_constructor,
+                return_log_prob=True,
+                default_interaction_type=InteractionType.RANDOM,
+            ),
+        )
+        # We want to make sure there is no warning
+        td = policy(TensorDict(batch_size=[4]))
+        assert isinstance(
+            policy.get_dist(td).log_prob(
+                td, aggregate_probabilities=False, inplace=False, include_sum=False
+            ),
+            TensorDict,
+        )
+        assert isinstance(
+            policy.log_prob(
+                td, aggregate_probabilities=False, inplace=False, include_sum=False
+            ),
+            TensorDict,
+        )
+        value_operator = Seq(
+            WrapModule(
+                lambda td: td.set("state_value", torch.ones((*td.shape, 1))),
+                out_keys=["state_value"],
+            )
+        )
+        for cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
+            data = policy(TensorDict(batch_size=[4]))
+            data.set(
+                "next",
+                TensorDict(
+                    reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)
+                ),
+            )
+            ppo = cls(policy, value_operator)
+            ppo.set_keys(
+                action=[
+                    ("agent0", "action"),
+                    ("agent1", "action"),
+                    ("agent2", "action"),
+                ],
+                sample_log_prob=[
+                    ("agent0", "action_log_prob"),
+                    ("agent1", "action_log_prob"),
+                    ("agent2", "action_log_prob"),
+                ],
+            )
+            loss = ppo(data)
+            loss.sum(reduce=True)
+
 
 class TestA2C(LossModuleTestBase):
     seed = 0
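
The new test stops at `loss.sum(reduce=True)`, which collapses the loss tensordict into a single scalar. For context, a minimal sketch of how that scalar could drive an optimizer step is given below; it is not part of the commit, the Adam optimizer and learning rate are illustrative choices, and `ppo`/`data` are assumed to be constructed as in `test_ppo_composite_dists`.

# Hedged sketch, not part of the commit: feed the aggregated loss into an optimizer step.
# Assumes `ppo` and `data` are built as in test_ppo_composite_dists above.
optim = torch.optim.Adam(ppo.parameters(), lr=1e-3)  # illustrative optimizer/lr
loss_td = ppo(data)                 # TensorDict holding loss_objective, loss_critic, ... entries
total = loss_td.sum(reduce=True)    # sum every entry into one scalar, as the test does
total.backward()
optim.step()
optim.zero_grad()

In an actual training loop one would typically sum only the `loss_*` entries, leaving diagnostic values out of the objective.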
