From c6016293d07438e41f19b7f41aa8a623655cfbc2 Mon Sep 17 00:00:00 2001
From: Erni <38285979+arnaujc91@users.noreply.github.com>
Date: Sun, 3 Mar 2024 00:09:39 +0100
Subject: [PATCH] Using dist.mode instead of logits.argmax (#1066)

Changed all the occurrences where an action is selected deterministically:

- **from**: using the outputs of the actor network.
- **to**: using the mode of the PyTorch distribution.

---------

Co-authored-by: Arnau Jimenez
---
 test/continuous/test_sac_with_il.py       |  2 --
 tianshou/policy/modelfree/discrete_sac.py |  2 +-
 tianshou/policy/modelfree/pg.py           | 15 +--------------
 tianshou/policy/modelfree/redq.py         |  5 ++++-
 tianshou/policy/modelfree/sac.py          |  4 ++--
 5 files changed, 8 insertions(+), 20 deletions(-)

diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py
index c0e17c62c..2b2e12e46 100644
--- a/test/continuous/test_sac_with_il.py
+++ b/test/continuous/test_sac_with_il.py
@@ -3,7 +3,6 @@

 import gymnasium as gym
 import numpy as np
-import pytest
 import torch
 from torch.utils.tensorboard import SummaryWriter

@@ -58,7 +57,6 @@ def get_args() -> argparse.Namespace:
     return parser.parse_known_args()[0]


-@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
 def test_sac_with_il(args: argparse.Namespace = get_args()) -> None:
     # if you want to use python vector env, please refer to other test scripts
     # train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed)
diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py
index 8a80f184c..b271cbd26 100644
--- a/tianshou/policy/modelfree/discrete_sac.py
+++ b/tianshou/policy/modelfree/discrete_sac.py
@@ -109,7 +109,7 @@ def forward(  # type: ignore
         logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
         dist = Categorical(logits=logits)
         if self.deterministic_eval and not self.training:
-            act = logits.argmax(axis=-1)
+            act = dist.mode
         else:
             act = dist.sample()
         return Batch(logits=logits, act=act, state=hidden, dist=dist)
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index 7db588be6..eb6cb5952 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -158,19 +158,6 @@ def process_fn(
         batch: BatchWithReturnsProtocol
         return batch

-    def _get_deterministic_action(self, logits: torch.Tensor) -> torch.Tensor:
-        if self.action_type == "discrete":
-            return logits.argmax(-1)
-        if self.action_type == "continuous":
-            # assume that the mode of the distribution is the first element
-            # of the actor's output (the "logits")
-            return logits[0]
-        raise RuntimeError(
-            f"Unknown action type: {self.action_type}. "
-            f"This should not happen and might be a bug."
-            f"Supported action types are: 'discrete' and 'continuous'.",
-        )
-
     def forward(
         self,
         batch: ObsBatchProtocol,
@@ -198,7 +185,7 @@ def forward(

         # in this case, the dist is unused!
         if self.deterministic_eval and not self.training:
-            act = self._get_deterministic_action(logits)
+            act = dist.mode
         else:
             act = dist.sample()
         result = Batch(logits=logits, act=act, state=hidden, dist=dist)
diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py
index d1b00714d..100a361f4 100644
--- a/tianshou/policy/modelfree/redq.py
+++ b/tianshou/policy/modelfree/redq.py
@@ -153,7 +153,10 @@ def forward(  # type: ignore
         loc_scale, h = self.actor(batch.obs, state=state, info=batch.info)
         loc, scale = loc_scale
         dist = Independent(Normal(loc, scale), 1)
-        act = loc if self.deterministic_eval and not self.training else dist.rsample()
+        if self.deterministic_eval and not self.training:
+            act = dist.mode
+        else:
+            act = dist.rsample()
         log_prob = dist.log_prob(act).unsqueeze(-1)
         # apply correction for Tanh squashing when computing logprob from Gaussian
         # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index f487e33e9..a4336247f 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -56,7 +56,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t
         This is useful when solving "hard exploration" problems.
         "default" is equivalent to GaussianNoise(sigma=0.1).
     :param deterministic_eval: whether to use deterministic action
-        (mean of Gaussian policy) in evaluation mode instead of stochastic
+        (mode of Gaussian policy) in evaluation mode instead of stochastic
         action sampled by the policy. Does not affect training.
     :param action_scaling: whether to map actions from range [-1, 1]
         to range[action_spaces.low, action_spaces.high].
@@ -177,7 +177,7 @@ def forward(  # type: ignore
         assert isinstance(logits, tuple)
         dist = Independent(Normal(*logits), 1)
         if self.deterministic_eval and not self.training:
-            act = logits[0]
+            act = dist.mode
         else:
             act = dist.rsample()
         log_prob = dist.log_prob(act).unsqueeze(-1)
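
A minimal sketch of why the substitution is behavior-preserving, assuming a PyTorch version whose distributions expose the `.mode` property (the same requirement the patch itself introduces): `Categorical(logits=...).mode` selects the same index as `logits.argmax(-1)`, and the mode of `Independent(Normal(loc, scale), 1)` is `loc`, so deterministic evaluation picks the same action as before. The tensors below are illustrative placeholders, not values from the patched code.

    import torch
    from torch.distributions import Categorical, Independent, Normal

    # Assumes PyTorch >= 1.12, where Distribution.mode is available.

    # Discrete case: the mode of a Categorical equals the argmax of its logits.
    logits = torch.randn(4, 6)
    assert torch.equal(Categorical(logits=logits).mode, logits.argmax(-1))

    # Continuous case: the mode of a diagonal Gaussian is its location parameter.
    loc, scale = torch.randn(4, 3), torch.rand(4, 3) + 0.1
    assert torch.equal(Independent(Normal(loc, scale), 1).mode, loc)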