# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Multi-head Agent and PPO Loss
=============================
This example demonstrates how to use TorchRL to create a multi-head agent with three separate distributions
(Gamma, Kumaraswamy, and Mixture) and train it using Proximal Policy Optimization (PPO) losses.

Step-by-step Explanation
------------------------

1. **Setting Composite Log-Probabilities**:
   - To use composite (i.e., multi-head) distributions with PPO (or any other algorithm that relies on probability
     distributions, such as SAC or A2C), you must call `set_composite_lp_aggregate(False).set()`. Failing to do so will
     result in errors during the execution of your script.
   - From torchrl and tensordict v0.9 onwards, this will be the default behavior. Without it,
     `CompositeDistribution` aggregates the log-probs, which may lead to incorrect log-probabilities.
   - Note that `set_composite_lp_aggregate(False).set()` will cause the sample log-probabilities to be named
     `<action_key>_log_prob` for any probability distribution, not just composite ones. For regular, single-head
     policies, for instance, the log-probability will be named `"action_log_prob"`.
     Previously, log-prob keys defaulted to `sample_log_prob`.
2. **Action Grouping**:
   - Actions can be grouped or not; PPO doesn't require them to be grouped.
   - If actions are grouped, calling the policy will result in a `TensorDict` with fields for each agent's action and
     log-probability, e.g., `agent0`, `agent0_log_prob`, etc.

    ... [...]
    ...     action: TensorDict(
    ...         fields={
    ...             agent0: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
    ...             agent0_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
    ...             agent1: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
    ...             agent1_log_prob: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
    ...             agent2: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
    ...             agent2_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
    ...         batch_size=torch.Size([4]),
    ...         device=None,
    ...         is_shared=False),

   - If actions are not grouped, each agent will have its own `TensorDict` with `action` and `action_log_prob` fields.

    ... [...]
    ...     agent0: TensorDict(
    ...         fields={
    ...             action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
    ...             action_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
    ...         batch_size=torch.Size([4]),
    ...         device=None,
    ...         is_shared=False),
    ...     agent1: TensorDict(
    ...         fields={
    ...             action: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
    ...             action_log_prob: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    ...         batch_size=torch.Size([4]),
    ...         device=None,
    ...         is_shared=False),
    ...     agent2: TensorDict(
    ...         fields={
    ...             action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
    ...             action_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
    ...         batch_size=torch.Size([4]),
    ...         device=None,
    ...         is_shared=False),

3. **PPO Loss Calculation**:
   - Under the hood, `ClipPPO` clips each head's importance weight individually (rather than clipping an aggregated
     weight) and multiplies it by the advantage; a schematic sketch is given right after this list.

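   As a rough sketch (simplified for illustration, not TorchRL's actual implementation; `new_log_prob`,
   `old_log_prob`, `advantage` and `clip_epsilon` are hypothetical per-head tensors / clip value):

    >>> ratio = (new_log_prob - old_log_prob).exp()  # one importance weight per action head
    >>> clipped_ratio = ratio.clamp(1 - clip_epsilon, 1 + clip_epsilon)
    >>> surrogate = torch.minimum(ratio * advantage, clipped_ratio * advantage)  # clipping applied per head
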
The code below sets up a multi-head agent with three distributions and demonstrates how to train it using PPO losses.

"""

import functools

import torch
from tensordict import TensorDict
from tensordict.nn import (
    CompositeDistribution,
    InteractionType,
    ProbabilisticTensorDictModule as Prob,
    ProbabilisticTensorDictSequential as ProbSeq,
    set_composite_lp_aggregate,
    TensorDictModule as Mod,
    TensorDictSequential as Seq,
    WrapModule as Wrap,
)
from torch import distributions as d
from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss

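# Keep per-head log-probs separate (written as `<action_key>_log_prob`) instead of letting
# CompositeDistribution aggregate them into a single value (see point 1 of the docstring).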
set_composite_lp_aggregate(False).set()

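# Toggle between the two key layouts described in point 2 of the docstring:
# grouped -> ("action", "agentN"), ungrouped -> ("agentN", "action").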
GROUPED_ACTIONS = False

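# A parameter-less module that emits fixed parameters for the three heads:
# Gamma (concentration, rate), Kumaraswamy (concentration0, concentration1),
# and a 10-component Normal mixture (logits, loc, scale).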
make_params = Mod(
    lambda: (
        torch.ones(4),
        torch.ones(4),
        torch.ones(4, 2),
        torch.ones(4, 2),
        torch.ones(4, 10) / 10,
        torch.zeros(4, 10),
        torch.ones(4, 10),
    ),
    in_keys=[],
    out_keys=[
        ("params", "gamma", "concentration"),
        ("params", "gamma", "rate"),
        ("params", "Kumaraswamy", "concentration0"),
        ("params", "Kumaraswamy", "concentration1"),
        ("params", "mixture", "logits"),
        ("params", "mixture", "loc"),
        ("params", "mixture", "scale"),
    ],
)


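# Build a mixture-of-Gaussians distribution: a Categorical over the mixture
# components parameterized by `logits`, with Normal(loc, scale) components.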
def mixture_constructor(logits, loc, scale):
    return d.MixtureSameFamily(
        d.Categorical(logits=logits), d.Normal(loc=loc, scale=scale)
    )


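# Map each distribution to the TensorDict key its sample is written to: either
# nested under a common "action" entry (grouped) or under one sub-TensorDict
# per agent (ungrouped).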
if GROUPED_ACTIONS:
    name_map = {
        "gamma": ("action", "agent0"),
        "Kumaraswamy": ("action", "agent1"),
        "mixture": ("action", "agent2"),
    }
else:
    name_map = {
        "gamma": ("agent0", "action"),
        "Kumaraswamy": ("agent1", "action"),
        "mixture": ("agent2", "action"),
    }

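# CompositeDistribution ties the three heads together: each distribution in
# `distribution_map` is parameterized by the matching entry under "params" and
# writes its sample to the key given in `name_map`.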
dist_constructor = functools.partial(
    CompositeDistribution,
    distribution_map={
        "gamma": d.Gamma,
        "Kumaraswamy": d.Kumaraswamy,
        "mixture": mixture_constructor,
    },
    name_map=name_map,
)


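# The policy chains the parameter module with a probabilistic module that builds
# the composite distribution from "params" and samples all three heads at once
# (RANDOM interaction -> stochastic sampling).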
policy = ProbSeq(
    make_params,
    Prob(
        in_keys=["params"],
        out_keys=list(name_map.values()),
        distribution_class=dist_constructor,
        return_log_prob=True,
        default_interaction_type=InteractionType.RANDOM,
    ),
)

td = policy(TensorDict(batch_size=[4]))
print("Result of policy call", td)

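# With aggregation disabled, log_prob returns a TensorDict carrying one
# log-probability entry per action head rather than a single summed tensor.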
dist = policy.get_dist(td)
log_prob = dist.log_prob(td)
print("Composite log-prob", log_prob)

# Build a dummy value operator
value_operator = Seq(
    Wrap(
        lambda td: td.set("state_value", torch.ones((*td.shape, 1))),
        out_keys=["state_value"],
    )
)

# Create fake data
data = policy(TensorDict(batch_size=[4]))
data.set(
    "next",
    TensorDict(reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)),
)

# Instantiate the loss - test the 3 different PPO losses
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
    # PPO sets the keys automatically by looking at the policy
    ppo = loss_cls(policy, value_operator)
    print("tensor keys", ppo.tensor_keys)

    # Get the loss values
    loss_vals = ppo(data)
    print("Loss result:", loss_cls, loss_vals)