Skip to content

[Feature] Adds per-head entropy coefficients to PPOLoss #2972

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 75 additions & 1 deletion test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
from torchrl.objectives.redq import REDQLoss
from torchrl.objectives.reinforce import ReinforceLoss
from torchrl.objectives.utils import (
_sum_td_features,
_vmap_func,
HardUpdate,
hold_out_net,
Expand Down Expand Up @@ -9734,7 +9735,8 @@ def mixture_constructor(logits, loc, scale):
reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)
),
)
ppo = cls(policy, value_operator)
scalar_entropy = 0.07
ppo = cls(policy, value_operator, entropy_coef=scalar_entropy)
ppo.set_keys(
action=[
("agent0", "action"),
Expand All @@ -9748,8 +9750,50 @@ def mixture_constructor(logits, loc, scale):
],
)
loss = ppo(data)
composite_entropy = loss["composite_entropy"]
entropy = _sum_td_features(composite_entropy)
expected_loss = -(scalar_entropy * entropy).mean() # batch mean
torch.testing.assert_close(
loss["loss_entropy"], expected_loss, rtol=1e-5, atol=1e-7
)
loss.sum(reduce=True)

# keep per-head entropies instead of the aggregated tensor
set_composite_lp_aggregate(False).set()
coef_map = {
"agent0": 0.10,
"agent1": 0.05,
"agent2": 0.02,
}
ppo_weighted = cls(policy, value_operator, entropy_coef=coef_map)
ppo_weighted.set_keys(
action=[
("agent0", "action"),
("agent1", "action"),
("agent2", "action"),
],
sample_log_prob=[
("agent0", "action_log_prob"),
("agent1", "action_log_prob"),
("agent2", "action_log_prob"),
],
)
loss = ppo_weighted(data)
composite_entropy = loss["composite_entropy"]

# sanity check: loss_entropy is scalar + finite
assert loss["loss_entropy"].ndim == 0
assert torch.isfinite(loss["loss_entropy"])
# Check individual loss is computed with the right weights
expected_loss = 0.0
for name, head_entropy in composite_entropy.items():
expected_loss -= (
coef_map[name] * _sum_td_features(head_entropy)
).mean()
torch.testing.assert_close(
loss["loss_entropy"], expected_loss, rtol=1e-5, atol=1e-7
)

def test_ppo_marl_aggregate(self):
env = MARLEnv()

Expand Down Expand Up @@ -9791,6 +9835,36 @@ def primer(td):
assert isinstance(ppo.tensor_keys.action, list)
ppo(output)

def _make_entropy_loss(self, entropy_coef):
    """Build a ``PPOLoss`` over fresh mock actor/critic modules with the given entropy coefficient."""
    actor_net, critic_net = self._create_mock_actor_value()
    return PPOLoss(actor_net, critic_net, entropy_coef=entropy_coef)

def test_weighted_entropy_scalar(self):
    """A scalar coefficient yields ``loss = -coef * entropy``."""
    ppo_loss = self._make_entropy_loss(entropy_coef=0.5)
    result = ppo_loss._weighted_loss_entropy(torch.tensor(2.0))
    torch.testing.assert_close(result, torch.tensor(-1.0))

def test_weighted_entropy_mapping(self):
    """Per-head coefficients weight each head's entropy before the negated sum."""
    head_coefs = {"head_0": 0.3, "head_1": 0.7}
    ppo_loss = self._make_entropy_loss(entropy_coef=head_coefs)
    entropy_td = TensorDict(
        {
            "head_0": {"action_log_prob": torch.tensor(1.0)},
            "head_1": {"action_log_prob": torch.tensor(2.0)},
        },
        [],
    )
    result = ppo_loss._weighted_loss_entropy(entropy_td)
    # Expected: -(c0 * e0 + c1 * e1) with the per-head entropies above.
    expected_value = -(head_coefs["head_0"] * 1.0 + head_coefs["head_1"] * 2.0)
    torch.testing.assert_close(result, torch.tensor(expected_value))

def test_weighted_entropy_mapping_missing_key(self):
    """A head absent from the coefficient mapping raises a ``KeyError``."""
    ppo_loss = self._make_entropy_loss(entropy_coef={"head_not_present": 0.5})
    entropy_td = TensorDict(
        {"head_0": {"action_log_prob": torch.tensor(1.0)}},
        [],
    )
    with pytest.raises(KeyError):
        ppo_loss._weighted_loss_entropy(entropy_td)


class TestA2C(LossModuleTestBase):
seed = 0
Expand Down
94 changes: 76 additions & 18 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Mapping

import torch
from tensordict import (
Expand Down Expand Up @@ -84,7 +85,9 @@ class PPOLoss(LossModule):
``samples_mc_entropy`` will control how many
samples will be used to compute this estimate.
Defaults to ``1``.
entropy_coef (scalar, optional): entropy multiplier when computing the total loss.
entropy_coef (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
* **Scalar**: one value applied to the summed entropy of every action head.
* **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action-head's entropy.
Defaults to ``0.01``.
critic_coef (scalar, optional): critic loss multiplier when computing the total
loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
Expand Down Expand Up @@ -330,7 +333,7 @@ def __init__(
*,
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
entropy_coef: float = 0.01,
entropy_coef: float | Mapping[str, float] = 0.01,
critic_coef: float | None = None,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
Expand Down Expand Up @@ -408,7 +411,22 @@ def __init__(
torch, "get_default_device", lambda: torch.device("cpu")
)()

self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
if isinstance(entropy_coef, Mapping):
# Store the mapping for per-head coefficients
self._entropy_coef_map = {str(k): float(v) for k, v in entropy_coef.items()}
# Register an empty buffer for compatibility
self.register_buffer("entropy_coef", torch.tensor(0.0))
elif isinstance(entropy_coef, (float, int, torch.Tensor)):
# Register the scalar entropy coefficient
coef = (
float(entropy_coef)
if not torch.is_tensor(entropy_coef)
else float(entropy_coef.item())
)
self.register_buffer("entropy_coef", torch.tensor(coef))
self._entropy_coef_map = None
else:
raise TypeError("entropy_coef must be a float or a Mapping[str, float]")
if critic_coef is not None:
self.register_buffer(
"critic_coef", torch.tensor(critic_coef, device=device)
Expand Down Expand Up @@ -540,7 +558,6 @@ def _get_entropy(
return entropy.unsqueeze(-1)

def _get_cur_log_prob(self, tensordict):

if isinstance(
self.actor_network,
(ProbabilisticTensorDictSequential, ProbabilisticTensorDictModule),
Expand Down Expand Up @@ -589,7 +606,6 @@ def _get_cur_log_prob(self, tensordict):
def _log_weight(
self, tensordict: TensorDictBase, adv_shape: torch.Size
) -> tuple[torch.Tensor, d.Distribution, torch.Tensor]:

prev_log_prob = _maybe_get_or_select(
tensordict,
self.tensor_keys.sample_log_prob,
Expand Down Expand Up @@ -745,9 +761,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if is_tensor_collection(entropy):
# Reports the entropy of each action head.
td_out.set("composite_entropy", entropy.detach())
entropy = _sum_td_features(entropy)
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
td_out.set(
"entropy", _sum_td_features(entropy).detach().mean()
) # for logging
else:
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
if self._has_critic:
loss_critic, value_clip_fraction = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
Expand Down Expand Up @@ -814,6 +833,35 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
}
self._value_estimator.set_keys(**tensor_keys)

def _weighted_loss_entropy(
    self, entropy: torch.Tensor | TensorDictBase
) -> torch.Tensor:
    """Compute the (negated) weighted entropy loss term.

    If ``self._entropy_coef_map`` is set, each action head's entropy is scaled
    by its own coefficient and the per-head terms are summed. Otherwise the
    scalar ``self.entropy_coef`` multiplies the entropy (summed over feature
    keys first when ``entropy`` is a tensordict).

    Args:
        entropy (torch.Tensor | TensorDictBase): either a plain entropy tensor
            or a tensordict with one sub-entry per action head.

    Returns:
        torch.Tensor: the entropy contribution to the total loss — the
        negative of the weighted entropy, so minimizing the loss maximizes
        entropy.

    Raises:
        KeyError: if a head present in ``entropy`` has no coefficient in
            ``self._entropy_coef_map``.
        ValueError: if the per-head mapping is active but ``entropy`` contains
            no heads.
    """
    if self._entropy_coef_map is None:
        if is_tensor_collection(entropy):
            entropy = _sum_td_features(entropy)
        return -self.entropy_coef * entropy

    loss_term = None  # running sum over heads
    for head_name, entropy_head in entropy.items():
        try:
            coef = self._entropy_coef_map[head_name]
        except KeyError as exc:
            raise KeyError(f"Missing entropy coef for head '{head_name}'") from exc
        # Match the head tensor's dtype/device so the multiply does not upcast.
        coef_t = torch.as_tensor(
            coef, dtype=entropy_head.dtype, device=entropy_head.device
        )
        head_loss_term = -coef_t * _sum_td_features(entropy_head)
        loss_term = (
            head_loss_term if loss_term is None else loss_term + head_loss_term
        )  # accumulate

    if loss_term is None:
        # Previously an empty entropy tensordict silently produced ``None``,
        # which would only fail later when written into the output tensordict.
        raise ValueError(
            "Cannot compute the entropy loss: no entropy head was found in "
            "the entropy tensordict."
        )
    return loss_term


class ClipPPOLoss(PPOLoss):
"""Clipped PPO loss.
Expand All @@ -836,7 +884,9 @@ class ClipPPOLoss(PPOLoss):
``samples_mc_entropy`` will control how many
samples will be used to compute this estimate.
Defaults to ``1``.
entropy_coef (scalar, optional): entropy multiplier when computing the total loss.
entropy_coef (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
* **Scalar**: one value applied to the summed entropy of every action head.
* **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action-head's entropy.
Defaults to ``0.01``.
critic_coef (scalar, optional): critic loss multiplier when computing the total
loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
Expand Down Expand Up @@ -939,7 +989,7 @@ def __init__(
clip_epsilon: float = 0.2,
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
entropy_coef: float = 0.01,
entropy_coef: float | Mapping[str, float] = 0.01,
critic_coef: float | None = None,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
Expand Down Expand Up @@ -1064,9 +1114,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if is_tensor_collection(entropy):
# Reports the entropy of each action head.
td_out.set("composite_entropy", entropy.detach())
entropy = _sum_td_features(entropy)
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
td_out.set(
"entropy", _sum_td_features(entropy).detach().mean()
) # for logging
else:
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
if self._has_critic:
loss_critic, value_clip_fraction = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
Expand Down Expand Up @@ -1120,7 +1173,9 @@ class KLPENPPOLoss(PPOLoss):
``samples_mc_entropy`` will control how many
samples will be used to compute this estimate.
Defaults to ``1``.
entropy_coef (scalar, optional): entropy multiplier when computing the total loss.
entropy_coef (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
* **Scalar**: one value applied to the summed entropy of every action head.
* **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action-head's entropy.
Defaults to ``0.01``.
critic_coef (scalar, optional): critic loss multiplier when computing the total
loss. Defaults to ``1.0``.
Expand Down Expand Up @@ -1224,7 +1279,7 @@ def __init__(
samples_mc_kl: int = 1,
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
entropy_coef: float = 0.01,
entropy_coef: float | Mapping[str, float] = 0.01,
critic_coef: float | None = None,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
Expand Down Expand Up @@ -1405,9 +1460,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
if is_tensor_collection(entropy):
# Reports the entropy of each action head.
td_out.set("composite_entropy", entropy.detach())
entropy = _sum_td_features(entropy)
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
td_out.set(
"entropy", _sum_td_features(entropy).detach().mean()
) # for logging
else:
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
if self._has_critic:
loss_critic, value_clip_fraction = self.loss_critic(tensordict_copy)
td_out.set("loss_critic", loss_critic)
Expand Down
Loading