Skip to content

Commit

Permalink
RandomActionPolicy: fixes for discrete case, added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Panchenko committed Aug 17, 2024
1 parent c8ef74e commit dcf1b2e
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 3 deletions.
44 changes: 44 additions & 0 deletions test/base/test_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import torch
from torch.distributions import Categorical, Distribution, Independent, Normal

from tianshou.data import Batch
from tianshou.policy import BasePolicy, PPOPolicy
from tianshou.policy.base import RandomActionPolicy
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor
Expand Down Expand Up @@ -77,3 +79,45 @@ def test_get_action(self, policy: PPOPolicy) -> None:
actions = [policy.compute_action(sample_obs) for _ in range(10)]
# check that the actions are the same in deterministic mode
assert len(set(map(_to_hashable, actions))) == 1

@staticmethod
def test_random_policy_discrete_actions() -> None:
action_space = gym.spaces.Discrete(3)
policy = RandomActionPolicy(action_space=action_space)

# forward of actor returns discrete probabilities, in compliance with the overall discrete actor
action_probs = policy.actor(np.zeros((10, 2)))[0]
assert np.allclose(action_probs, 1 / 3 * np.ones((10, 3)))

actions = []
for _ in range(10):
action = policy.compute_action(np.array([0]))
assert action_space.contains(action)
actions.append(action)

# not all actions are the same
assert len(set(actions)) > 1

# test batched forward
action_batch = policy(Batch(obs=np.zeros((10, 2))))
assert action_batch.act.shape == (10,)
assert len(set(action_batch.act.tolist())) > 1

@staticmethod
def test_random_policy_continuous_actions() -> None:
action_space = gym.spaces.Box(low=-1, high=1, shape=(3,))
policy = RandomActionPolicy(action_space=action_space)

actions = []
for _ in range(10):
action = policy.compute_action(np.array([0]))
assert action_space.contains(action)
actions.append(action)

# not all actions are the same
assert len(set(map(_to_hashable, actions))) > 1

# test batched forward
action_batch = policy(Batch(obs=np.zeros((10, 2))))
assert action_batch.act.shape == (10, 3)
assert len(set(map(_to_hashable, action_batch.act))) > 1
2 changes: 1 addition & 1 deletion tianshou/policy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ def forward(
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> ActStateBatchProtocol:
act, next_state = self.actor(batch.obs, state)
act, next_state = self.actor.compute_action_batch(batch.obs), state
return cast(ActStateBatchProtocol, Batch(act=act, state=next_state))

def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TrainingStats:
Expand Down
16 changes: 14 additions & 2 deletions tianshou/utils/net/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from gymnasium import spaces
from torch import nn

from tianshou.data.batch import Batch
from tianshou.data.batch import Batch, BatchProtocol
from tianshou.data.types import RecurrentStateBatch
from tianshou.utils.space_info import ActionSpaceInfo

Expand Down Expand Up @@ -661,9 +661,13 @@ def get_preprocess_net(self) -> nn.Module:
def get_output_dim(self) -> int:
return self.space_info.action_dim

@property
def is_discrete(self) -> bool:
return isinstance(self.action_space, spaces.Discrete)

def forward(
self,
obs: np.ndarray | torch.Tensor,
obs: np.ndarray | torch.Tensor | BatchProtocol,
state: Any | None = None,
info: dict[str, Any] | None = None,
) -> tuple[np.ndarray, Any | None]:
Expand All @@ -675,6 +679,14 @@ def forward(
action = 1 / self.action_space.n * np.ones((batch_size, self.action_space.n))
return action, state

def compute_action_batch(self, obs: np.ndarray | torch.Tensor | BatchProtocol) -> np.ndarray:
if self.is_discrete:
# Different from forward which returns discrete probabilities, see comment there
assert isinstance(self.action_space, spaces.Discrete) # for mypy
return np.random.randint(low=0, high=self.action_space.n, size=len(obs))
else:
return self.forward(obs)[0]


def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T:
"""Gets the given attribute from the given object or takes the alternative value if it is not present.
Expand Down

0 comments on commit dcf1b2e

Please sign in to comment.