Support for using random agents, improvements in CollectStats #1207

Merged · 13 commits · Aug 26, 2024
6 changes: 3 additions & 3 deletions docs/01_tutorials/04_tictactoe.rst
@@ -122,7 +122,7 @@ Two Random Agents

.. Figure:: ../_static/images/marl.png

Tianshou already provides some builtin classes for multi-agent learning. You can check out the API documentation for details. Here we use :class:`~tianshou.policy.RandomPolicy` and :class:`~tianshou.policy.MultiAgentPolicyManager`. The figure on the right gives an intuitive explanation.
Tianshou already provides some builtin classes for multi-agent learning. You can check out the API documentation for details. Here we use :class:`~tianshou.policy.MARLRandomPolicy` and :class:`~tianshou.policy.MultiAgentPolicyManager`. The figure on the right gives an intuitive explanation.

::

@@ -202,7 +202,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modules
BasePolicy,
DQNPolicy,
MultiAgentPolicyManager,
RandomPolicy,
MARLRandomPolicy,
)
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger
@@ -286,7 +286,7 @@ The following ``get_agents`` function returns agents and their optimizers from e…

- The action model we use is an instance of :class:`~tianshou.utils.net.common.Net`, essentially a multi-layer perceptron with the ReLU activation function;
- The network model is passed to a :class:`~tianshou.policy.DQNPolicy`, where actions are selected according to both the action mask and their Q-values;
- The opponent can be either a random agent :class:`~tianshou.policy.RandomPolicy` that randomly chooses an action from legal actions, or it can be a pre-trained :class:`~tianshou.policy.DQNPolicy` allowing learned agents to play with themselves.
- The opponent can be either a random agent :class:`~tianshou.policy.MARLRandomPolicy` that randomly chooses an action from legal actions, or it can be a pre-trained :class:`~tianshou.policy.DQNPolicy` allowing learned agents to play with themselves.

Both agents are passed to :class:`~tianshou.policy.MultiAgentPolicyManager`, which is responsible to call the correct agent according to the ``agent_id`` in the observation. :class:`~tianshou.policy.MultiAgentPolicyManager` also dispatches data to each agent according to ``agent_id``, so that each agent seems to play with a virtual single-agent environment.
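For orientation, here is a minimal sketch of the wiring this paragraph describes: two policies handed to one ``MultiAgentPolicyManager``. It assumes PettingZoo's ``tictactoe_v3`` environment and keyword-style arguments for the manager, which are not taken verbatim from the tutorial; the construction actually used in this PR is in ``test/pettingzoo/tic_tac_toe.py`` further down.

    # Sketch only: tictactoe_v3 and the keyword names below are assumptions.
    from pettingzoo.classic import tictactoe_v3

    from tianshou.env.pettingzoo_env import PettingZooEnv
    from tianshou.policy import MARLRandomPolicy, MultiAgentPolicyManager

    env = PettingZooEnv(tictactoe_v3.env())
    # One policy per player; either slot could instead hold a (pre-trained) DQNPolicy.
    agents = [MARLRandomPolicy(action_space=env.action_space) for _ in range(2)]
    policy = MultiAgentPolicyManager(policies=agents, env=env)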

2 changes: 2 additions & 0 deletions docs/spelling_wordlist.txt
@@ -288,3 +288,5 @@ monte
carlo
subclass
subclassing
dist
dists
60 changes: 59 additions & 1 deletion test/base/test_batch.py
@@ -9,10 +9,11 @@
import pytest
import torch
from deepdiff import DeepDiff
from torch.distributions import Distribution, Independent, Normal
from torch.distributions.categorical import Categorical

from tianshou.data import Batch, to_numpy, to_torch
from tianshou.data.batch import IndexType, get_sliced_dist
from tianshou.data.batch import IndexType, dist_to_atleast_2d, get_sliced_dist


def test_batch() -> None:
@@ -766,6 +767,63 @@ def test_batch_over_batch_to_torch() -> None:
assert batch.b.d.dtype == torch.float32
assert batch.b.e.dtype == torch.float32

@staticmethod
@pytest.mark.parametrize(
"dist, expected_batch_shape",
[
(Categorical(probs=torch.tensor([0.3, 0.7])), (1,)),
(Categorical(probs=torch.tensor([[0.3, 0.7], [0.4, 0.6]])), (2,)),
(Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)), (1,)),
(Normal(loc=torch.tensor([0.0, 1.0]), scale=torch.tensor([1.0, 2.0])), (2,)),
(Independent(Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)), 0), (1,)),
(
Independent(
Normal(loc=torch.tensor([0.0, 1.0]), scale=torch.tensor([1.0, 2.0])),
0,
),
(2,),
),
],
)
def test_dist_to_atleast_2d(dist: Distribution, expected_batch_shape: tuple[int]) -> None:
result = dist_to_atleast_2d(dist)
assert result.batch_shape == expected_batch_shape

# Additionally check that the parameters are correctly transformed
if isinstance(dist, Categorical):
assert isinstance(result, Categorical)
assert result.probs.shape[:-1] == expected_batch_shape
elif isinstance(dist, Normal):
assert isinstance(result, Normal)
assert result.loc.shape == expected_batch_shape
assert result.scale.shape == expected_batch_shape
elif isinstance(dist, Independent):
assert isinstance(result, Independent)
assert result.base_dist.batch_shape == expected_batch_shape

@staticmethod
@pytest.mark.parametrize(
"dist",
[
Categorical(probs=torch.tensor([0.3, 0.7])),
Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)),
Independent(Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)), 0),
],
)
def test_dist_to_atleast_2d_idempotent(dist: Distribution) -> None:
result1 = dist_to_atleast_2d(dist)
result2 = dist_to_atleast_2d(result1)
assert result1 == result2

@staticmethod
def test_batch_to_atleast_2d() -> None:
scalar_batch = Batch(a=1, b=2, dist=Categorical(probs=torch.ones(3)))
assert scalar_batch.dist.batch_shape == ()
assert scalar_batch.a.shape == scalar_batch.b.shape == ()
scalar_batch_2d = scalar_batch.to_at_least_2d()
assert scalar_batch_2d.dist.batch_shape == (1,)
assert scalar_batch_2d.a.shape == scalar_batch_2d.b.shape == (1, 1)


class TestAssignment:
@staticmethod
45 changes: 44 additions & 1 deletion test/base/test_policy.py
@@ -4,8 +4,9 @@
import torch
from torch.distributions import Categorical, Distribution, Independent, Normal

from tianshou.data import Batch
from tianshou.policy import BasePolicy, PPOPolicy
from tianshou.policy.base import episode_mc_return_to_go
from tianshou.policy.base import RandomActionPolicy, episode_mc_return_to_go
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor
@@ -85,3 +86,45 @@ def test_get_action(self, policy: PPOPolicy) -> None:
actions = [policy.compute_action(sample_obs) for _ in range(10)]
# check that the actions are the same in deterministic mode
assert len(set(map(_to_hashable, actions))) == 1

@staticmethod
def test_random_policy_discrete_actions() -> None:
action_space = gym.spaces.Discrete(3)
policy = RandomActionPolicy(action_space=action_space)

# forward of actor returns discrete probabilities, in compliance with the overall discrete actor
action_probs = policy.actor(np.zeros((10, 2)))[0]
assert np.allclose(action_probs, 1 / 3 * np.ones((10, 3)))

actions = []
for _ in range(10):
action = policy.compute_action(np.array([0]))
assert action_space.contains(action)
actions.append(action)

# not all actions are the same
assert len(set(actions)) > 1

# test batched forward
action_batch = policy(Batch(obs=np.zeros((10, 2))))
assert action_batch.act.shape == (10,)
assert len(set(action_batch.act.tolist())) > 1

@staticmethod
def test_random_policy_continuous_actions() -> None:
action_space = gym.spaces.Box(low=-1, high=1, shape=(3,))
policy = RandomActionPolicy(action_space=action_space)

actions = []
for _ in range(10):
action = policy.compute_action(np.array([0]))
assert action_space.contains(action)
actions.append(action)

# not all actions are the same
assert len(set(map(_to_hashable, actions))) > 1

# test batched forward
action_batch = policy(Batch(obs=np.zeros((10, 2))))
assert action_batch.act.shape == (10, 3)
assert len(set(map(_to_hashable, action_batch.act))) > 1
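As a usage note, the ``RandomActionPolicy`` exercised by these tests can be dropped into a ``Collector`` to gather purely random transitions, for example to warm up a replay buffer. The sketch below is an assumption-laden illustration: the CartPole environment choice and the ``Collector``/``VectorReplayBuffer`` argument names come from the public tianshou API, not from this diff.

    # Hedged sketch: pre-fill a buffer with random actions.
    import gymnasium as gym

    from tianshou.data import Collector, VectorReplayBuffer
    from tianshou.env import DummyVectorEnv
    from tianshou.policy.base import RandomActionPolicy

    envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
    policy = RandomActionPolicy(action_space=gym.make("CartPole-v1").action_space)
    buffer = VectorReplayBuffer(total_size=2000, buffer_num=len(envs))
    collector = Collector(policy, envs, buffer)
    collector.reset()  # explicit reset before collecting
    stats = collector.collect(n_step=200)  # returns a CollectStats instance
    print(stats.n_collected_steps)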
42 changes: 42 additions & 0 deletions test/base/test_stats.py
@@ -1,5 +1,12 @@
from typing import cast

import numpy as np
import pytest
import torch
from torch.distributions import Categorical, Normal

from tianshou.data import Batch, CollectStats
from tianshou.data.collector import CollectStepBatchProtocol, get_stddev_from_dist
from tianshou.policy.base import TrainingStats, TrainingStatsWrapper


@@ -47,3 +54,38 @@ def test_training_stats_wrapper() -> None:
"loss_field",
), "Attribute `loss_field` not found in `wrapped_train_stats`."
assert wrapped_train_stats.wrapped_stats.loss_field == wrapped_train_stats.loss_field == 13

@staticmethod
@pytest.mark.parametrize(
"act,dist",
(
(np.array(1), Categorical(probs=torch.tensor([0.5, 0.5]))),
(np.array([1, 2, 3]), Normal(torch.zeros(3), torch.ones(3))),
),
)
def test_collect_stats_update_at_step(
act: np.ndarray,
dist: torch.distributions.Distribution,
) -> None:
step_batch = cast(
CollectStepBatchProtocol,
Batch(
info={},
obs=np.array([1, 2, 3]),
obs_next=np.array([4, 5, 6]),
act=act,
rew=np.array(1.0),
done=np.array(False),
terminated=np.array(False),
dist=dist,
).to_at_least_2d(),
)
stats = CollectStats()
for _ in range(10):
stats.update_at_step_batch(step_batch)
stats.refresh_all_sequence_stats()
assert stats.n_collected_steps == 10
assert stats.pred_dist_std_array is not None
assert np.allclose(stats.pred_dist_std_array, get_stddev_from_dist(dist))
assert stats.pred_dist_std_array_stat is not None
assert stats.pred_dist_std_array_stat[0].mean == get_stddev_from_dist(dist)[0].item()
9 changes: 7 additions & 2 deletions test/pettingzoo/tic_tac_toe.py
@@ -13,7 +13,12 @@
from tianshou.data.stats import InfoStats
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy
from tianshou.policy import (
BasePolicy,
DQNPolicy,
MARLRandomPolicy,
MultiAgentPolicyManager,
)
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
@@ -131,7 +136,7 @@ def get_agents(
agent_opponent = deepcopy(agent_learn)
agent_opponent.load_state_dict(torch.load(args.opponent_path))
else:
agent_opponent = RandomPolicy(action_space=env.action_space)
agent_opponent = MARLRandomPolicy(action_space=env.action_space)

if args.agent_id == 1:
agents = [agent_learn, agent_opponent]
39 changes: 38 additions & 1 deletion tianshou/data/batch.py
@@ -285,6 +285,23 @@ def get_len_of_dist(dist: Distribution) -> int:
return dist.batch_shape[0]


def dist_to_atleast_2d(dist: TDistribution) -> TDistribution:
"""Convert a distribution to at least 2D, such that the `batch_shape` attribute has a len of at least 1."""
if len(dist.batch_shape) > 0:
return dist
if isinstance(dist, Categorical):
return Categorical(probs=dist.probs.unsqueeze(0)) # type: ignore[return-value]
elif isinstance(dist, Normal):
return Normal(loc=dist.loc.unsqueeze(0), scale=dist.scale.unsqueeze(0)) # type: ignore[return-value]
elif isinstance(dist, Independent):
return Independent(
dist_to_atleast_2d(dist.base_dist),
dist.reinterpreted_batch_ndims,
) # type: ignore[return-value]
else:
raise NotImplementedError(f"Unsupported distribution for conversion to 2D: {type(dist)}")


# Note: This is implemented as a protocol because the interface
# of Batch is always extended by adding new fields. Having a hierarchy of
# protocols building off this one allows for type safety and IDE support despite
@@ -602,6 +619,14 @@ def get(self, key: str, default: Any | None = None) -> Any:
def pop(self, key: str, default: Any | None = None) -> Any:
raise ProtocolCalledException

def to_at_least_2d(self) -> Self:
"""Ensures that all arrays and dists in the batch have at least 2 dimensions.

This is useful for ensuring that all arrays in the batch can be concatenated
along a new axis.
"""
raise ProtocolCalledException


class Batch(BatchProtocol):
"""See :class:`~tianshou.data.batch.BatchProtocol`."""
@@ -1160,7 +1185,7 @@ def __len__(self) -> int:
if isinstance(obj, Distribution):
lens.append(get_len_of_dist(obj))
continue
raise TypeError(f"Entry for {key} in {self} is {obj}has no len()")
raise TypeError(f"Entry for {key} in {self} is {obj} has no len()")
if not lens:
return 0
return min(lens)
@@ -1326,6 +1351,18 @@ def replace_empty_batches_by_none(self) -> None:
else:
val.replace_empty_batches_by_none()

def to_at_least_2d(self) -> Self:
"""Ensures that all arrays and dists in the batch have at least 2 dimensions.

This is useful for ensuring that all arrays in the batch can be concatenated
along a new axis.
"""
result = self.apply_values_transform(np.atleast_2d, inplace=False)
for key, val in self.items():
if isinstance(val, Distribution):
result[key] = dist_to_atleast_2d(val)
return result


def _apply_batch_values_func_recursively(
batch: TBatch,
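For reference, a short sketch of how the new helpers behave, mirroring the assertions in ``test/base/test_batch.py`` above:

    import torch
    from torch.distributions import Categorical

    from tianshou.data import Batch
    from tianshou.data.batch import dist_to_atleast_2d

    scalar_batch = Batch(a=1, dist=Categorical(probs=torch.ones(3)))
    batched = scalar_batch.to_at_least_2d()
    assert batched.a.shape == (1, 1)          # scalar entries become 2D arrays
    assert batched.dist.batch_shape == (1,)   # dists gain a leading batch dimension
    # already-batched distributions are returned unchanged (idempotent)
    assert dist_to_atleast_2d(batched.dist) is batched.dist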