From ee90cb5cd33dd6af796f513766989a722dd7c000 Mon Sep 17 00:00:00 2001
From: Dominik Jain
Date: Wed, 7 Aug 2024 15:23:08 +0200
Subject: [PATCH 1/2] High-level API: Establish a strong link between the
 actor and the distribution function (dist_fn) used in policies by creating
 the distribution function in the actor factory, which knows which function
 is appropriate

Consequently, remove the policy parameter 'dist_fn' from the high-level API,
because it is now determined automatically, eliminating the possibility of
misspecification by the user.

[breaking change: code must not specify the 'dist_fn' parameter, but
persisted objects continue to work as expected]

Implements #1194
---
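Note (patch commentary, not part of the commit message): the user-facing
effect is sketched below, loosely following examples/mujoco/mujoco_ppo_hl.py.
The builder variable, hidden_sizes, and the with_ppo_params call are assumed
from the example's context; unrelated parameters are elided.

    import torch

    from tianshou.highlevel.params.policy_params import PPOParams

    # Before this patch, the user had to select the matching distribution
    # function factory by hand (and could pick a wrong one):
    #     PPOParams(lr=3e-4, dist_fn=DistributionFunctionFactoryIndependentGaussians())

    # After this patch, the distribution function is derived from the actor
    # factory, which knows which distribution fits its actor's outputs:
    builder.with_ppo_params(
        PPOParams(lr=3e-4),
    ).with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)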
 CHANGELOG.md                               |  7 +++-
 examples/atari/atari_network.py            | 16 ++++++++-
 examples/mujoco/mujoco_npg_hl.py           |  4 ---
 examples/mujoco/mujoco_ppo_hl.py           |  4 ---
 examples/mujoco/mujoco_ppo_hl_multi.py     |  4 ---
 examples/mujoco/mujoco_trpo_hl.py          |  4 ---
 test/highlevel/test_experiment_builder.py  |  1 +
 tianshou/highlevel/agent.py                |  4 +++
 tianshou/highlevel/module/actor.py         | 42 +++++++++++++++++++---
 tianshou/highlevel/params/dist_fn.py       | 36 +++++++++++--------
 tianshou/highlevel/params/policy_params.py | 27 +++----------
 11 files changed, 89 insertions(+), 60 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 84fd582e2..aa29e70a6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,7 +2,7 @@
 
 ## Release 1.1.0
 
-### Api Extensions
+### Changes/Improvements
 - `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074 #1141 #1183
 - `data`:
   - `Batch`:
@@ -107,6 +107,11 @@ continuous and discrete cases. #1032
 - `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077
 - `AtariEnvFactory` constructor (in examples, so not really breaking) now requires explicit train and test seeds. #1074
 - `EnvFactoryRegistered` now requires an explicit `test_seed` in the constructor. #1074
+- `highlevel`:
+  - The parameter `dist_fn` has been removed from the parameter objects (`PGParams`, `A2CParams`, `PPOParams`, `NPGParams`, `TRPOParams`).
+    The correct distribution is now determined automatically based on the actor factory being used, avoiding the possibility of
+    misspecification. Persisted configurations/policies continue to work as expected, but code must not specify the `dist_fn` parameter.
+    #1194 #1195
 
 ### Tests
diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py
index 4f2a5600a..87797f760 100644
--- a/examples/atari/atari_network.py
+++ b/examples/atari/atari_network.py
@@ -14,6 +14,8 @@
     IntermediateModule,
     IntermediateModuleFactory,
 )
+from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryCategorical
+from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
 from tianshou.utils.net.common import NetBase
 from tianshou.utils.net.discrete import Actor, NoisyLinear
 
@@ -246,6 +248,8 @@ def forward(
 
 
 class ActorFactoryAtariDQN(ActorFactory):
+    USE_SOFTMAX_OUTPUT = False
+
     def __init__(
         self,
         scale_obs: bool = True,
@@ -274,7 +278,17 @@ def create_module(self, envs: Environments, device: TDevice) -> Actor:
         )
         if self.scale_obs:
             net = scale_obs(net)
-        return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)
+        return Actor(
+            net,
+            envs.get_action_shape(),
+            device=device,
+            softmax_output=self.USE_SOFTMAX_OUTPUT,
+        ).to(device)
+
+    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
+        return DistributionFunctionFactoryCategorical(
+            is_probs_input=self.USE_SOFTMAX_OUTPUT,
+        ).create_dist_fn(envs)
 
 
 class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py
index ab265a87a..231e735c9 100644
--- a/examples/mujoco/mujoco_npg_hl.py
+++ b/examples/mujoco/mujoco_npg_hl.py
@@ -12,9 +12,6 @@
     ExperimentConfig,
     NPGExperimentBuilder,
 )
-from tianshou.highlevel.params.dist_fn import (
-    DistributionFunctionFactoryIndependentGaussians,
-)
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
 from tianshou.highlevel.params.policy_params import NPGParams
 from tianshou.utils import logging
@@ -78,7 +75,6 @@ def main(
                 lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
                 if lr_decay
                 else None,
-                dist_fn=DistributionFunctionFactoryIndependentGaussians(),
             ),
         )
         .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py
index 27a701b12..af0c5ab8f 100644
--- a/examples/mujoco/mujoco_ppo_hl.py
+++ b/examples/mujoco/mujoco_ppo_hl.py
@@ -12,9 +12,6 @@
     ExperimentConfig,
     PPOExperimentBuilder,
 )
-from tianshou.highlevel.params.dist_fn import (
-    DistributionFunctionFactoryIndependentGaussians,
-)
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
 from tianshou.highlevel.params.policy_params import PPOParams
 from tianshou.utils import logging
@@ -88,7 +85,6 @@ def main(
                 lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
                 if lr_decay
                 else None,
-                dist_fn=DistributionFunctionFactoryIndependentGaussians(),
             ),
         )
         .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py
index 6d140386e..c9d5a8fde 100644
--- a/examples/mujoco/mujoco_ppo_hl_multi.py
+++ b/examples/mujoco/mujoco_ppo_hl_multi.py
@@ -26,9 +26,6 @@
     PPOExperimentBuilder,
 )
 from tianshou.highlevel.logger import LoggerFactoryDefault
-from tianshou.highlevel.params.dist_fn import (
-    DistributionFunctionFactoryIndependentGaussians,
-)
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
 from tianshou.highlevel.params.policy_params import PPOParams
 from tianshou.utils import logging
@@ -115,7 +112,6 @@ def main(
                 recompute_advantage=True,
                 lr=3e-4,
                 lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config),
-                dist_fn=DistributionFunctionFactoryIndependentGaussians(),
             ),
         )
         .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py
index f54d4c312..59929529e 100644
--- a/examples/mujoco/mujoco_trpo_hl.py
+++ b/examples/mujoco/mujoco_trpo_hl.py
@@ -12,9 +12,6 @@
     ExperimentConfig,
     TRPOExperimentBuilder,
 )
-from tianshou.highlevel.params.dist_fn import (
-    DistributionFunctionFactoryIndependentGaussians,
-)
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
 from tianshou.highlevel.params.policy_params import TRPOParams
 from tianshou.utils import logging
@@ -82,7 +79,6 @@ def main(
                 lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
                 if lr_decay
                 else None,
-                dist_fn=DistributionFunctionFactoryIndependentGaussians(),
             ),
         )
         .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
diff --git a/test/highlevel/test_experiment_builder.py b/test/highlevel/test_experiment_builder.py
index cb52c5ae3..5e61ac832 100644
--- a/test/highlevel/test_experiment_builder.py
+++ b/test/highlevel/test_experiment_builder.py
@@ -56,6 +56,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime
 @pytest.mark.parametrize(
     "builder_cls",
     [
+        PGExperimentBuilder,
         PPOExperimentBuilder,
         A2CExperimentBuilder,
         DQNExperimentBuilder,
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index 6c35710a6..03b4e1463 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -273,11 +273,14 @@ def _create_policy(self, envs: Environments, device: TDevice) -> PGPolicy:
                 optim_factory=self.optim_factory,
             ),
         )
+        dist_fn = self.actor_factory.create_dist_fn(envs)
+        assert dist_fn is not None
         return PGPolicy(
             actor=actor.module,
             optim=actor.optim,
             action_space=envs.get_action_space(),
             observation_space=envs.get_observation_space(),
+            dist_fn=dist_fn,
             **kwargs,
         )
 
@@ -333,6 +336,7 @@ def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
         kwargs["critic"] = actor_critic.critic
         kwargs["optim"] = actor_critic.optim
         kwargs["action_space"] = envs.get_action_space()
+        kwargs["dist_fn"] = self.actor_factory.create_dist_fn(envs)
         return kwargs
 
     def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py
index 867ece17a..5c1387e35 100644
--- a/tianshou/highlevel/module/actor.py
+++ b/tianshou/highlevel/module/actor.py
@@ -19,6 +19,11 @@
 )
 from tianshou.highlevel.module.module_opt import ModuleOpt
 from tianshou.highlevel.optim import OptimizerFactory
+from tianshou.highlevel.params.dist_fn import (
+    DistributionFunctionFactoryCategorical,
+    DistributionFunctionFactoryIndependentGaussians,
+)
+from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
 from tianshou.utils.net import continuous, discrete
 from tianshou.utils.net.common import BaseActor, ModuleType, Net
 from tianshou.utils.string import ToStringMixin
@@ -47,6 +52,14 @@ class ActorFactory(ModuleFactory, ToStringMixin, ABC):
     def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
         pass
 
+    @abstractmethod
+    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
+        """
+        :param envs: the environments
+        :return: the distribution function, which converts the actor's output into a distribution,
+            or None if the actor does not output distribution parameters
+        """
+
     def create_module_opt(
         self,
         envs: Environments,
@@ -70,7 +83,7 @@ def create_module_opt(
 def _init_linear(actor: torch.nn.Module) -> None:
     """Initializes linear layers of an actor module using default mechanisms.
 
-    :param module: the actor module.
+    :param actor: the actor module.
     """
     init_linear_orthogonal(actor)
     if hasattr(actor, "mu"):
@@ -104,7 +117,7 @@ def __init__(
         self.hidden_activation = hidden_activation
         self.discrete_softmax = discrete_softmax
 
-    def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
+    def _create_factory(self, envs: Environments) -> ActorFactory:
         env_type = envs.get_type()
         factory: ActorFactoryContinuousDeterministicNet | ActorFactoryContinuousGaussianNet | ActorFactoryDiscreteNet
         if env_type == EnvType.CONTINUOUS:
@@ -125,15 +138,22 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
                     raise ValueError("Continuous action spaces are not supported by the algorithm")
                 case _:
                     raise ValueError(self.continuous_actor_type)
-            return factory.create_module(envs, device)
         elif env_type == EnvType.DISCRETE:
             factory = ActorFactoryDiscreteNet(
                 self.DEFAULT_HIDDEN_SIZES,
                 softmax_output=self.discrete_softmax,
             )
-            return factory.create_module(envs, device)
         else:
             raise ValueError(f"{env_type} not supported")
+        return factory
+
+    def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
+        factory = self._create_factory(envs)
+        return factory.create_module(envs, device)
+
+    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
+        factory = self._create_factory(envs)
+        return factory.create_dist_fn(envs)
 
 
 class ActorFactoryContinuous(ActorFactory, ABC):
@@ -159,6 +179,9 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
             device=device,
         ).to(device)
 
+    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
+        return None
+
 
 class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
     def __init__(
@@ -202,6 +225,9 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
 
         return actor
 
+    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
+        return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs)
+
 
 class ActorFactoryDiscreteNet(ActorFactory):
     def __init__(
@@ -229,6 +255,11 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
             softmax_output=self.softmax_output,
         ).to(device)
 
+    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
+        return DistributionFunctionFactoryCategorical(
+            is_probs_input=self.softmax_output,
+        ).create_dist_fn(envs)
+
 
 class ActorFactoryTransientStorageDecorator(ActorFactory):
     """Wraps an actor factory, storing the most recently created actor instance such that it can be retrieved."""
@@ -254,6 +285,9 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.M
         self._actor_future.actor = module
         return module
 
+    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
+        return self.actor_factory.create_dist_fn(envs)
+
 
 class IntermediateModuleFactoryFromActorFactory(IntermediateModuleFactory):
     def __init__(self, actor_factory: ActorFactory):
diff --git a/tianshou/highlevel/params/dist_fn.py b/tianshou/highlevel/params/dist_fn.py
index c8d2aca9e..d28a2166c 100644
--- a/tianshou/highlevel/params/dist_fn.py
+++ b/tianshou/highlevel/params/dist_fn.py
@@ -4,7 +4,7 @@
 
 import torch
 
-from tianshou.highlevel.env import Environments, EnvType
+from tianshou.highlevel.env import Environments
 from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont
 from tianshou.utils.string import ToStringMixin
 
@@ -20,13 +20,29 @@ def create_dist_fn(
 
 
 class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
+    def __init__(self, is_probs_input: bool = True):
+        """
+        :param is_probs_input: If True, the distribution function shall create a categorical distribution
+            from a tensor containing probabilities; otherwise the tensor is assumed to contain logits.
+        """
+        self.is_probs_input = is_probs_input
+
     def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete:
         envs.get_type().assert_discrete(self)
-        return self._dist_fn
+        if self.is_probs_input:
+            return self._dist_fn_probs
+        else:
+            return self._dist_fn
+
+    # NOTE: Do not move/rename because a reference to the function can appear in persisted policies
+    @staticmethod
+    def _dist_fn(logits: torch.Tensor) -> torch.distributions.Categorical:
+        return torch.distributions.Categorical(logits=logits)
 
+    # NOTE: Do not move/rename because a reference to the function can appear in persisted policies
     @staticmethod
-    def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical:
-        return torch.distributions.Categorical(logits=p)
+    def _dist_fn_probs(probs: torch.Tensor) -> torch.distributions.Categorical:
+        return torch.distributions.Categorical(probs=probs)
 
 
 class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
@@ -34,18 +50,8 @@ def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
         envs.get_type().assert_continuous(self)
         return self._dist_fn
 
+    # NOTE: Do not move/rename because a reference to the function can appear in persisted policies
     @staticmethod
     def _dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution:
         loc, scale = loc_scale
         return torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1)
-
-
-class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
-    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
-        match envs.get_type():
-            case EnvType.DISCRETE:
-                return DistributionFunctionFactoryCategorical().create_dist_fn(envs)
-            case EnvType.CONTINUOUS:
-                return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs)
-            case _:
-                raise ValueError(envs.get_type())
diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py
index 79bfcf918..5f521d418 100644
--- a/tianshou/highlevel/params/policy_params.py
+++ b/tianshou/highlevel/params/policy_params.py
@@ -12,15 +12,11 @@
 from tianshou.highlevel.module.module_opt import ModuleOpt
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.highlevel.params.alpha import AutoAlphaFactory
-from tianshou.highlevel.params.dist_fn import (
-    DistributionFunctionFactory,
-    DistributionFunctionFactoryDefault,
-)
 from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
 from tianshou.highlevel.params.noise import NoiseFactory
-from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
 from tianshou.utils import MultipleLRSchedulers
+from tianshou.utils.pickle import setstate
 from tianshou.utils.string import ToStringMixin
 
 
@@ -209,15 +205,6 @@ def change_value(self, value: Any, data: ParamTransformerData) -> Any:
         return value
 
 
-class ParamTransformerDistributionFunction(ParamTransformerChangeValue):
-    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
-        if value == "default":
-            value = DistributionFunctionFactoryDefault().create_dist_fn(data.envs)
-        elif isinstance(value, DistributionFunctionFactory):
-            value = value.create_dist_fn(data.envs)
-        return value
-
-
 class ParamTransformerActionScaling(ParamTransformerChangeValue):
     def change_value(self, value: Any, data: ParamTransformerData) -> Any:
         if value == "default":
@@ -322,20 +309,14 @@ class PGParams(Params, ParamsMixinActionScaling, ParamsMixinLearningRateWithSche
     whether to use deterministic action (the dist's mode) instead of stochastic one
     during evaluation. Does not affect training.
     """
-    dist_fn: TDistFnDiscrOrCont | DistributionFunctionFactory | Literal["default"] = "default"
-    """
-    This can either be a function which maps the model output to a torch distribution or a
-    factory for the creation of such a function.
-    When set to "default", a factory which creates Gaussian distributions from mean and standard
-    deviation will be used for the continuous case and which creates categorical distributions
-    for the discrete case (see :class:`DistributionFunctionFactoryDefault`)
-    """
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        setstate(PGParams, self, state, removed_properties=["dist_fn"])
 
     def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
         transformers.extend(ParamsMixinActionScaling._get_param_transformers(self))
         transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
-        transformers.append(ParamTransformerDistributionFunction("dist_fn"))
         return transformers

From 8114dd18e4dfeffe3c29c963d99b356a6bcb67b2 Mon Sep 17 00:00:00 2001
From: Dominik Jain
Date: Thu, 8 Aug 2024 10:02:46 +0200
Subject: [PATCH 2/2] Improve docstring of CriticFactoryReuseActor

---
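Note (patch commentary, not part of the commit message): the mechanism
described by the improved docstring, reduced to a rough sketch. ActorFuture,
CriticFactoryReuseActor, and the actor attribute assignment are taken from
the code base and the first patch; the remaining names and signatures are
illustrative only.

    future = ActorFuture()                             # empty at declaration time
    critic_factory = CriticFactoryReuseActor(future)   # receives the future, not an actor

    actor = actor_factory.create_module(envs, device)  # the builder creates the actor...
    future.actor = actor                               # ...and fills the future
                                                       # (cf. ActorFactoryTransientStorageDecorator)

    # Only now is the critic created; the factory can read future.actor and
    # reuse the actor's preprocessing network for the critic.
    critic = critic_factory.create_module(envs, device, use_action=False)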
 tianshou/highlevel/module/critic.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py
index 4eacef115..9e0fd070b 100644
--- a/tianshou/highlevel/module/critic.py
+++ b/tianshou/highlevel/module/critic.py
@@ -165,6 +165,12 @@ class CriticFactoryReuseActor(CriticFactory):
     """A critic factory which reuses the actor's preprocessing component.
 
     This class is for internal use in experiment builders only.
+
+    Reuse of the actor network is supported through the concept of an actor future (:class:`ActorFuture`).
+    When the user declares that the actor is to be reused for the critic, this factory is applied, but at
+    that point, the actor does not yet exist. The factory therefore receives the future instead, which is
+    filled once the actor factory has been invoked. By the time this factory's creation method is called,
+    the future holds the actor, and the critic can be created from it.
     """
 
     def __init__(self, actor_future: ActorFuture):