High-level API: Establish a strong link between the actor and the distribution function #1195

Merged · 2 commits · Aug 8, 2024
7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -2,7 +2,7 @@

## Release 1.1.0

### Api Extensions
### Changes/Improvements
- `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074 #1141 #1183
- `data`:
- `Batch`:
@@ -107,6 +107,11 @@ continuous and discrete cases. #1032
- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077
- `AtariEnvFactory` constructor (in examples, so not really breaking) now requires explicit train and test seeds. #1074
- `EnvFactoryRegistered` now requires an explicit `test_seed` in the constructor. #1074
- `highlevel`:
- The parameter `dist_fn` has been removed from the parameter objects (`PGParams`, `A2CParams`, `PPOParams`, `NPGParams`, `TRPOParams`).
The correct distribution is now determined automatically from the actor factory being used, eliminating the possibility of
misspecification. Persisted configurations/policies continue to work as expected, but user code that explicitly passed the `dist_fn` parameter must be updated to drop it.
#1194 #1195


### Tests
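To make the changelog entry concrete, here is a minimal sketch of the resulting high-level usage (the parameter values are placeholders, not recommendations):

```python
from tianshou.highlevel.params.policy_params import PPOParams

# The distribution function is now derived from the actor factory, so the
# parameter object neither accepts nor needs a `dist_fn` argument:
params = PPOParams(
    discount_factor=0.99,
    lr=3e-4,
    # dist_fn=DistributionFunctionFactoryIndependentGaussians(),  # removed in #1195
)
```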
16 changes: 15 additions & 1 deletion examples/atari/atari_network.py
@@ -14,6 +14,8 @@
IntermediateModule,
IntermediateModuleFactory,
)
from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryCategorical
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.common import NetBase
from tianshou.utils.net.discrete import Actor, NoisyLinear

@@ -246,6 +248,8 @@ def forward(


class ActorFactoryAtariDQN(ActorFactory):
USE_SOFTMAX_OUTPUT = False

def __init__(
self,
scale_obs: bool = True,
@@ -274,7 +278,17 @@ def create_module(self, envs: Environments, device: TDevice) -> Actor:
)
if self.scale_obs:
net = scale_obs(net)
return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)
return Actor(
net,
envs.get_action_shape(),
device=device,
softmax_output=self.USE_SOFTMAX_OUTPUT,
).to(device)

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
return DistributionFunctionFactoryCategorical(
is_probs_input=self.USE_SOFTMAX_OUTPUT,
).create_dist_fn(envs)


class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
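Because `USE_SOFTMAX_OUTPUT` is `False`, the actor emits raw logits, and `create_dist_fn` accordingly requests a logits-based categorical. A plain-PyTorch sketch of the failure mode this coupling prevents, namely feeding one output format into a distribution expecting the other:

```python
import torch

probs = torch.tensor([0.7, 0.2, 0.1])  # actor output *after* a softmax
right = torch.distributions.Categorical(probs=probs)
wrong = torch.distributions.Categorical(logits=probs)  # misspecified dist_fn

print(right.probs)  # tensor([0.7000, 0.2000, 0.1000])
print(wrong.probs)  # ~tensor([0.4640, 0.2814, 0.2546]) -- silently distorted
```

Deriving both `softmax_output` and the distribution function's input format from the single `USE_SOFTMAX_OUTPUT` constant makes this mismatch impossible by construction.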
4 changes: 0 additions & 4 deletions examples/mujoco/mujoco_npg_hl.py
@@ -12,9 +12,6 @@
ExperimentConfig,
NPGExperimentBuilder,
)
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import NPGParams
from tianshou.utils import logging
@@ -78,7 +75,6 @@ def main(
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
if lr_decay
else None,
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
),
)
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
4 changes: 0 additions & 4 deletions examples/mujoco/mujoco_ppo_hl.py
@@ -12,9 +12,6 @@
ExperimentConfig,
PPOExperimentBuilder,
)
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils import logging
@@ -88,7 +85,6 @@ def main(
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
if lr_decay
else None,
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
),
)
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
4 changes: 0 additions & 4 deletions examples/mujoco/mujoco_ppo_hl_multi.py
@@ -26,9 +26,6 @@
PPOExperimentBuilder,
)
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils import logging
@@ -115,7 +112,6 @@ def main(
recompute_advantage=True,
lr=3e-4,
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config),
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
),
)
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
4 changes: 0 additions & 4 deletions examples/mujoco/mujoco_trpo_hl.py
@@ -12,9 +12,6 @@
ExperimentConfig,
TRPOExperimentBuilder,
)
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import TRPOParams
from tianshou.utils import logging
@@ -82,7 +79,6 @@ def main(
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
if lr_decay
else None,
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
),
)
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
1 change: 1 addition & 0 deletions test/highlevel/test_experiment_builder.py
@@ -56,6 +56,7 @@
@pytest.mark.parametrize(
"builder_cls",
[
PGExperimentBuilder,
PPOExperimentBuilder,
A2CExperimentBuilder,
DQNExperimentBuilder,
4 changes: 4 additions & 0 deletions tianshou/highlevel/agent.py
@@ -273,11 +273,14 @@ def _create_policy(self, envs: Environments, device: TDevice) -> PGPolicy:
optim_factory=self.optim_factory,
),
)
dist_fn = self.actor_factory.create_dist_fn(envs)
assert dist_fn is not None
return PGPolicy(
actor=actor.module,
optim=actor.optim,
action_space=envs.get_action_space(),
observation_space=envs.get_observation_space(),
dist_fn=dist_fn,
**kwargs,
)

@@ -333,6 +336,7 @@ def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
kwargs["critic"] = actor_critic.critic
kwargs["optim"] = actor_critic.optim
kwargs["action_space"] = envs.get_action_space()
kwargs["dist_fn"] = self.actor_factory.create_dist_fn(envs)
return kwargs

def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
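The agent factories thus obtain the distribution function from the same factory that builds the actor, and the `assert` documents that on-policy algorithms require a stochastic actor: factories for deterministic actors (e.g. `ActorFactoryContinuousDeterministicNet` in the next file) return `None` from `create_dist_fn` and cannot be combined with policy-gradient algorithms. A condensed, illustrative restatement of the contract (not actual tianshou code):

```python
# dist_fn comes from the same factory that built the actor, so the two
# can no longer disagree about the actor's output format:
dist_fn = actor_factory.create_dist_fn(envs)
if dist_fn is None:
    raise ValueError(
        "This actor factory produces a deterministic actor and cannot "
        "be used with an algorithm that samples from a distribution."
    )
policy_kwargs["dist_fn"] = dist_fn  # hypothetical kwargs dict, as in _create_kwargs
```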
42 changes: 38 additions & 4 deletions tianshou/highlevel/module/actor.py
@@ -19,6 +19,11 @@
)
from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactoryCategorical,
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import BaseActor, ModuleType, Net
from tianshou.utils.string import ToStringMixin
@@ -47,6 +52,14 @@ class ActorFactory(ModuleFactory, ToStringMixin, ABC):
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
pass

@abstractmethod
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
"""
:param envs: the environments
:return: the distribution function, which converts the actor's output into a distribution, or None
if the actor does not output distribution parameters
"""

def create_module_opt(
self,
envs: Environments,
@@ -70,7 +83,7 @@ def create_module_opt(
def _init_linear(actor: torch.nn.Module) -> None:
"""Initializes linear layers of an actor module using default mechanisms.

:param module: the actor module.
:param actor: the actor module.
"""
init_linear_orthogonal(actor)
if hasattr(actor, "mu"):
@@ -104,7 +117,7 @@ def __init__(
self.hidden_activation = hidden_activation
self.discrete_softmax = discrete_softmax

def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
def _create_factory(self, envs: Environments) -> ActorFactory:
env_type = envs.get_type()
factory: ActorFactoryContinuousDeterministicNet | ActorFactoryContinuousGaussianNet | ActorFactoryDiscreteNet
if env_type == EnvType.CONTINUOUS:
@@ -125,15 +138,22 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
raise ValueError("Continuous action spaces are not supported by the algorithm")
case _:
raise ValueError(self.continuous_actor_type)
return factory.create_module(envs, device)
elif env_type == EnvType.DISCRETE:
factory = ActorFactoryDiscreteNet(
self.DEFAULT_HIDDEN_SIZES,
softmax_output=self.discrete_softmax,
)
return factory.create_module(envs, device)
else:
raise ValueError(f"{env_type} not supported")
return factory

def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
factory = self._create_factory(envs)
return factory.create_module(envs, device)

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
factory = self._create_factory(envs)
return factory.create_dist_fn(envs)


class ActorFactoryContinuous(ActorFactory, ABC):
@@ -159,6 +179,9 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
device=device,
).to(device)

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
return None


class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
def __init__(
@@ -202,6 +225,9 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor:

return actor

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs)


class ActorFactoryDiscreteNet(ActorFactory):
def __init__(
@@ -229,6 +255,11 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
softmax_output=self.softmax_output,
).to(device)

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
return DistributionFunctionFactoryCategorical(
is_probs_input=self.softmax_output,
).create_dist_fn(envs)


class ActorFactoryTransientStorageDecorator(ActorFactory):
"""Wraps an actor factory, storing the most recently created actor instance such that it can be retrieved."""
Expand All @@ -254,6 +285,9 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.M
self._actor_future.actor = module
return module

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
return self.actor_factory.create_dist_fn(envs)


class IntermediateModuleFactoryFromActorFactory(IntermediateModuleFactory):
def __init__(self, actor_factory: ActorFactory):
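With `create_dist_fn` now abstract on `ActorFactory`, any custom actor factory must declare which distribution matches its actor's output. A hypothetical sketch for a discrete, logits-emitting actor (the class and the omitted network construction are illustrative; import paths follow those in the diff above):

```python
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import ActorFactory
from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryCategorical
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont


class MyLogitsActorFactory(ActorFactory):
    """Hypothetical factory whose actor outputs raw (unnormalized) logits."""

    def create_module(self, envs, device):
        ...  # build and return an actor module emitting logits

    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
        # The actor emits logits, so request a logits-based categorical:
        return DistributionFunctionFactoryCategorical(is_probs_input=False).create_dist_fn(envs)
```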
6 changes: 6 additions & 0 deletions tianshou/highlevel/module/critic.py
@@ -165,6 +165,12 @@ class CriticFactoryReuseActor(CriticFactory):
"""A critic factory which reuses the actor's preprocessing component.

This class is for internal use in experiment builders only.

Reuse of the actor network is supported through the concept of an actor future (:class:`ActorFuture`).
When the user declares that the actor is to be reused for the critic, this factory is applied, but at that
point the actor does not yet exist. The factory therefore receives the future instead, which is filled once
the actor factory has been called. When this factory's creation method is subsequently called, it can use
the then-filled actor to create the critic.
"""

def __init__(self, actor_future: ActorFuture):
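A schematic of the future-based indirection the docstring describes (simplified and illustrative, not the actual implementation):

```python
# Simplified sketch: the critic factory is handed an empty "future"
# before the actor exists, and reads it only at creation time.
class ActorFuture:
    actor = None  # filled once the actor factory has run


future = ActorFuture()
critic_factory = CriticFactoryReuseActor(future)  # created first
# ... later, the actor factory creates the actor and fills the future:
# future.actor = actor_module
# Only now can critic_factory build a critic that reuses
# future.actor's preprocessing network.
```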
36 changes: 21 additions & 15 deletions tianshou/highlevel/params/dist_fn.py
@@ -4,7 +4,7 @@

import torch

from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.env import Environments
from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont
from tianshou.utils.string import ToStringMixin

@@ -20,32 +20,38 @@ def create_dist_fn(


class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
def __init__(self, is_probs_input: bool = True):
"""
:param is_probs_input: If True, the distribution function shall create a categorical distribution from a
tensor containing probabilities; otherwise the tensor is assumed to contain logits.
"""
self.is_probs_input = is_probs_input

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete:
envs.get_type().assert_discrete(self)
return self._dist_fn
if self.is_probs_input:
return self._dist_fn_probs
else:
return self._dist_fn

# NOTE: Do not move/rename because a reference to the function can appear in persisted policies
@staticmethod
def _dist_fn(logits: torch.Tensor) -> torch.distributions.Categorical:
return torch.distributions.Categorical(logits=logits)

# NOTE: Do not move/rename because a reference to the function can appear in persisted policies
@staticmethod
def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical:
return torch.distributions.Categorical(logits=p)
def _dist_fn_probs(probs: torch.Tensor) -> torch.distributions.Categorical:
return torch.distributions.Categorical(probs=probs)


class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
envs.get_type().assert_continuous(self)
return self._dist_fn

# NOTE: Do not move/rename because a reference to the function can appear in persisted policies
@staticmethod
def _dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution:
loc, scale = loc_scale
return torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1)


class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
match envs.get_type():
case EnvType.DISCRETE:
return DistributionFunctionFactoryCategorical().create_dist_fn(envs)
case EnvType.CONTINUOUS:
return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs)
case _:
raise ValueError(envs.get_type())
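Note that the diff also removes `DistributionFunctionFactoryDefault` (and the now-unused `EnvType` import): its dispatch on the environment type now lives in the actor factories themselves. Usage of the remaining categorical factory, sketched under the assumption that `envs` is a discrete-action `Environments` instance:

```python
import torch

from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryCategorical

# envs: an Environments instance with a discrete action space (assumed given)
dist_fn = DistributionFunctionFactoryCategorical(is_probs_input=False).create_dist_fn(envs)
dist = dist_fn(torch.randn(4))  # actor output: raw logits for a 4-action space
action = dist.sample()
```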