diff --git a/test/base/test_stats.py b/test/base/test_stats.py
index 9776374ba..6a17a8bcf 100644
--- a/test/base/test_stats.py
+++ b/test/base/test_stats.py
@@ -1,5 +1,12 @@
+from typing import cast
+
+import numpy as np
 import pytest
+import torch
+from torch.distributions import Categorical, Normal
 
+from tianshou.data import Batch, CollectStats
+from tianshou.data.collector import CollectStepBatchProtocol
 from tianshou.policy.base import TrainingStats, TrainingStatsWrapper
 
 
@@ -47,3 +54,37 @@ def test_training_stats_wrapper() -> None:
             "loss_field",
         ), "Attribute `loss_field` not found in `wrapped_train_stats`."
         assert wrapped_train_stats.wrapped_stats.loss_field == wrapped_train_stats.loss_field == 13
+
+    @staticmethod
+    @pytest.mark.parametrize(
+        "act,dist",
+        (
+            (np.array(1), Categorical(probs=torch.tensor([0.5, 0.5]))),
+            (np.array([1, 2, 3]), Normal(torch.zeros(3), torch.ones(3))),
+        ),
+    )
+    def test_collect_stats_update_at_step(
+        act: np.ndarray,
+        dist: torch.distributions.Distribution,
+    ) -> None:
+        step_batch = cast(
+            CollectStepBatchProtocol,
+            Batch(
+                info={},
+                obs=np.array([1, 2, 3]),
+                obs_next=np.array([4, 5, 6]),
+                act=act,
+                rew=np.array(1.0),
+                done=np.array(False),
+                terminated=np.array(False),
+                dist=dist,
+            ),
+        )
+        stats = CollectStats()
+        stats.update_at_step_batch(step_batch, refresh_sequence_stats=True)
+        assert stats.lens_stat is not None
+        assert stats.lens_stat.mean == 1.0
+        assert stats.pred_dist_std_array is not None
+        assert np.allclose(stats.pred_dist_std_array, dist.stddev)
+        assert stats.pred_dist_std_array_stat is not None
+        assert stats.pred_dist_std_array_stat[0].mean == dist.stddev[0].item()
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 2c615923d..cbf43a976 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -23,6 +23,7 @@
     to_numpy,
 )
 from tianshou.data.buffer.base import MalformedBufferError
+from tianshou.data.stats import compute_dim_to_summary_stats
 from tianshou.data.types import (
     ActBatchProtocol,
     DistBatchProtocol,
@@ -115,10 +116,10 @@ class CollectStats(CollectStatsBase):
     """The collected episode lengths."""
     lens_stat: SequenceSummaryStats | None = None
     """Stats of the collected episode lengths."""
-    std_array: np.ndarray | None = None
+    pred_dist_std_array: np.ndarray | None = None
     """The standard deviations of the predicted distributions."""
-    std_array_stat: SequenceSummaryStats | None = None
-    """Stats of the standard deviations of the predicted distributions."""
+    pred_dist_std_array_stat: dict[int, SequenceSummaryStats] | None = None
+    """Stats of the standard deviations of the predicted distributions (maps action dim to stats)."""
 
     @classmethod
     def with_autogenerated_stats(
@@ -152,10 +153,12 @@ def update_at_step_batch(
         self.n_collected_steps += len(step_batch)
         action_std = step_batch.dist.stddev if step_batch.dist is not None else None
         if action_std is not None:
-            if self.std_array is None:
-                self.std_array = to_numpy(action_std)
+            if self.pred_dist_std_array is None:
+                self.pred_dist_std_array = np.atleast_2d(to_numpy(action_std))
             else:
-                self.std_array = np.concatenate((self.std_array, to_numpy(action_std)))
+                self.pred_dist_std_array = np.concatenate(
+                    (self.pred_dist_std_array, to_numpy(action_std)),
+                )
         if refresh_sequence_stats:
             self.refresh_std_array_stats()
 
@@ -208,10 +211,11 @@ def refresh_len_stats(self) -> None:
             self.lens_stat = None
 
     def refresh_std_array_stats(self) -> None:
-        if self.std_array is not None and self.std_array.size > 0:
-            self.std_array_stat = SequenceSummaryStats.from_sequence(self.std_array)
+        if self.pred_dist_std_array is not None and self.pred_dist_std_array.size > 0:
+            # need to use .T because the action dim is supposed to be the first axis in compute_dim_to_summary_stats
+            self.pred_dist_std_array_stat = compute_dim_to_summary_stats(self.pred_dist_std_array.T)
         else:
-            self.std_array_stat = None
+            self.pred_dist_std_array_stat = None
 
     def refresh_all_sequence_stats(self) -> None:
         self.refresh_return_stats()
diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py
index ed64a429d..11d64c017 100644
--- a/tianshou/data/stats.py
+++ b/tianshou/data/stats.py
@@ -1,3 +1,4 @@
+import logging
 from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional
@@ -10,6 +11,8 @@
     from tianshou.data import CollectStats, CollectStatsBase
     from tianshou.policy.base import TrainingStats
 
+log = logging.getLogger(__name__)
+
 
 @dataclass(kw_only=True)
 class SequenceSummaryStats(DataclassPPrintMixin):
@@ -24,6 +27,14 @@ class SequenceSummaryStats(DataclassPPrintMixin):
     def from_sequence(cls, sequence: Sequence[float | int] | np.ndarray) -> "SequenceSummaryStats":
         if len(sequence) == 0:
             return cls(mean=0.0, std=0.0, max=0.0, min=0.0)
+
+        if hasattr(sequence, "shape") and len(sequence.shape) > 1:
+            log.warning(
+                f"Sequence has shape {sequence.shape}, but only 1D sequences are supported. "
+                "Stats will be computed from the flattened sequence. For computing stats "
+                "for each dimension, consider using the function `compute_dim_to_summary_stats`.",
+            )
+
         return cls(
             mean=float(np.mean(sequence)),
             std=float(np.std(sequence)),
@@ -32,6 +43,20 @@ def from_sequence(cls, sequence: Sequence[float | int] | np.ndarray) -> "Sequenc
         )
 
 
+def compute_dim_to_summary_stats(
+    arr: Sequence[Sequence[float]] | np.ndarray,
+) -> dict[int, SequenceSummaryStats]:
+    """Compute summary statistics for each dimension of the input, where dimensions are laid out along the first axis.
+
+    :param arr: a 2-dim array (or sequence of sequences) from which to compute the statistics.
+    :return: a dictionary mapping each dimension index to its summary statistics.
+    """
+    stats = {}
+    for dim, seq in enumerate(arr):
+        stats[dim] = SequenceSummaryStats.from_sequence(seq)
+    return stats
+
+
 @dataclass(kw_only=True)
 class TimingStats(DataclassPPrintMixin):
     """A data structure for storing timing statistics."""
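Not part of the diff: a minimal usage sketch of the new `compute_dim_to_summary_stats` helper, assuming a tianshou build that includes this change (the array values below are made up). It illustrates why `refresh_std_array_stats` transposes `pred_dist_std_array`, which is stored as (n_steps, action_dim), before calling the helper, since the helper iterates over the first axis.

```python
import numpy as np

from tianshou.data.stats import compute_dim_to_summary_stats

# Made-up per-step stds of the predicted action distribution for a
# 3-dimensional action space, collected over 4 steps.
# This mirrors the layout of CollectStats.pred_dist_std_array: (n_steps, action_dim).
pred_dist_std_array = np.array(
    [
        [0.10, 0.20, 0.30],
        [0.12, 0.18, 0.31],
        [0.11, 0.22, 0.29],
        [0.09, 0.21, 0.32],
    ],
)

# Transpose so the action dimension is on the first axis, since
# compute_dim_to_summary_stats iterates over the first axis
# (the same reason for the .T in refresh_std_array_stats).
dim_to_stats = compute_dim_to_summary_stats(pred_dist_std_array.T)

for dim, summary in dim_to_stats.items():
    # Each value is a SequenceSummaryStats with mean/std/max/min over the steps.
    print(dim, summary.mean, summary.std, summary.min, summary.max)
```

With these inputs, `dim_to_stats` has keys 0, 1, and 2, and, for example, `dim_to_stats[0].mean` is the average predicted std of the first action dimension across the four steps.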