CollectStats: better collection for std of actions (not flattening). Added tests, renamed entries.
Michael Panchenko committed Aug 24, 2024
1 parent 6f8648a commit 3a1de8c
Showing 3 changed files with 78 additions and 9 deletions.
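
In short, the flattened `std_array` / `std_array_stat` pair becomes `pred_dist_std_array` / `pred_dist_std_array_stat`, and the stats are now kept per action dimension instead of being computed over all dimensions at once. A minimal sketch of the difference, using the helpers touched by this commit (the array values are made up and assume a 2-dimensional action space):

```python
import numpy as np

from tianshou.data.stats import SequenceSummaryStats, compute_dim_to_summary_stats

# Made-up per-step stddevs of the predicted action distribution, shape (n_steps, action_dim).
pred_std = np.array([[0.1, 1.0], [0.2, 2.0], [0.3, 3.0]])

# Old behaviour (std_array_stat): one summary over all values, mixing both action dimensions.
flat_stats = SequenceSummaryStats.from_sequence(pred_std.flatten())
print(flat_stats.mean)  # ~1.1

# New behaviour (pred_dist_std_array_stat): one summary per action dimension.
# Note the transpose: compute_dim_to_summary_stats iterates over the first axis.
dim_stats = compute_dim_to_summary_stats(pred_std.T)
print(dim_stats[0].mean, dim_stats[1].mean)  # ~0.2 and 2.0
```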
40 changes: 40 additions & 0 deletions test/base/test_stats.py
@@ -1,5 +1,12 @@
from typing import cast

import numpy as np
import pytest
import torch
from torch.distributions import Categorical, Normal

from tianshou.data import Batch, CollectStats
from tianshou.data.collector import CollectStepBatchProtocol
from tianshou.policy.base import TrainingStats, TrainingStatsWrapper


@@ -47,3 +54,36 @@ def test_training_stats_wrapper() -> None:
"loss_field",
), "Attribute `loss_field` not found in `wrapped_train_stats`."
assert wrapped_train_stats.wrapped_stats.loss_field == wrapped_train_stats.loss_field == 13

@staticmethod
@pytest.mark.parametrize(
"act,dist",
(
(np.array(1), Categorical(probs=torch.tensor([0.5, 0.5]))),
(np.array([1, 2, 3]), Normal(torch.zeros(3), torch.ones(3))),
),
)
def test_collect_stats_update_at_step(
act: np.ndarray,
dist: torch.distributions.Distribution,
) -> None:
step_batch = cast(
CollectStepBatchProtocol,
Batch(
info={},
obs=np.array([1, 2, 3]),
obs_next=np.array([4, 5, 6]),
act=act,
rew=np.array(1.0),
done=np.array(False),
terminated=np.array(False),
dist=dist,
),
)
stats = CollectStats()
stats.update_at_step_batch(step_batch, refresh_sequence_stats=True)
assert stats.lens_stat is not None
assert stats.lens_stat.mean == 1.0
assert stats.pred_dist_std_array is not None
assert np.allclose(stats.pred_dist_std_array, dist.stddev)
assert stats.pred_dist_std_array_stat is not None
assert stats.pred_dist_std_array_stat[0].mean == dist.stddev[0].item()
22 changes: 13 additions & 9 deletions tianshou/data/collector.py
@@ -23,6 +23,7 @@
to_numpy,
)
from tianshou.data.buffer.base import MalformedBufferError
from tianshou.data.stats import compute_dim_to_summary_stats
from tianshou.data.types import (
ActBatchProtocol,
DistBatchProtocol,
@@ -115,10 +116,10 @@ class CollectStats(CollectStatsBase):
"""The collected episode lengths."""
lens_stat: SequenceSummaryStats | None = None
"""Stats of the collected episode lengths."""
std_array: np.ndarray | None = None
pred_dist_std_array: np.ndarray | None = None
"""The standard deviations of the predicted distributions."""
std_array_stat: SequenceSummaryStats | None = None
"""Stats of the standard deviations of the predicted distributions."""
pred_dist_std_array_stat: dict[int, SequenceSummaryStats] | None = None
"""Stats of the standard deviations of the predicted distributions (maps action dim to stats)"""

@classmethod
def with_autogenerated_stats(
@@ -152,10 +153,12 @@ def update_at_step_batch(
self.n_collected_steps += len(step_batch)
action_std = step_batch.dist.stddev if step_batch.dist is not None else None
if action_std is not None:
if self.std_array is None:
self.std_array = to_numpy(action_std)
if self.pred_dist_std_array is None:
self.pred_dist_std_array = np.atleast_2d(to_numpy(action_std))
else:
self.std_array = np.concatenate((self.std_array, to_numpy(action_std)))
self.pred_dist_std_array = np.concatenate(
(self.pred_dist_std_array, to_numpy(action_std)),
)
if refresh_sequence_stats:
self.refresh_std_array_stats()

@@ -208,10 +211,11 @@ def refresh_len_stats(self) -> None:
self.lens_stat = None

def refresh_std_array_stats(self) -> None:
if self.std_array is not None and self.std_array.size > 0:
self.std_array_stat = SequenceSummaryStats.from_sequence(self.std_array)
if self.pred_dist_std_array is not None and self.pred_dist_std_array.size > 0:
# need to use .T because the action dim is supposed to be the first axis in compute_dim_to_summary_stats
self.pred_dist_std_array_stat = compute_dim_to_summary_stats(self.pred_dist_std_array.T)
else:
self.std_array_stat = None
self.pred_dist_std_array_stat = None

def refresh_all_sequence_stats(self) -> None:
self.refresh_return_stats()
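
For reference, the accumulation in `update_at_step_batch` now keeps the stddevs as a 2-D array (`np.atleast_2d` on first use, concatenation along the step axis afterwards), and `refresh_std_array_stats` transposes it so that the action dimension becomes the first axis expected by `compute_dim_to_summary_stats`. A rough standalone sketch of that logic, with hypothetical per-step arrays standing in for `to_numpy(step_batch.dist.stddev)`:

```python
import numpy as np

from tianshou.data.stats import compute_dim_to_summary_stats

# Hypothetical per-step stddevs for a 2-dimensional action space.
step_stds = [np.array([0.1, 1.0]), np.array([0.2, 2.0]), np.array([0.3, 3.0])]

pred_dist_std_array = None
for action_std in step_stds:
    if pred_dist_std_array is None:
        # First step: promote the (action_dim,) array to shape (1, action_dim).
        pred_dist_std_array = np.atleast_2d(action_std)
    else:
        # Later steps: append along the step axis (the real collector concatenates
        # whatever array the policy's dist yields for the current step batch).
        pred_dist_std_array = np.concatenate((pred_dist_std_array, np.atleast_2d(action_std)))

# Shape (n_steps, action_dim); transpose so each row holds one action dimension.
dim_stats = compute_dim_to_summary_stats(pred_dist_std_array.T)
assert np.isclose(dim_stats[0].mean, 0.2) and np.isclose(dim_stats[1].max, 3.0)
```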
25 changes: 25 additions & 0 deletions tianshou/data/stats.py
@@ -1,3 +1,4 @@
import logging
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
@@ -10,6 +11,8 @@
from tianshou.data import CollectStats, CollectStatsBase
from tianshou.policy.base import TrainingStats

log = logging.getLogger(__name__)


@dataclass(kw_only=True)
class SequenceSummaryStats(DataclassPPrintMixin):
@@ -24,6 +27,14 @@ class SequenceSummaryStats(DataclassPPrintMixin):
def from_sequence(cls, sequence: Sequence[float | int] | np.ndarray) -> "SequenceSummaryStats":
if len(sequence) == 0:
return cls(mean=0.0, std=0.0, max=0.0, min=0.0)

if hasattr(sequence, "shape") and len(sequence.shape) > 1:
log.warning(
f"Sequence has shape {sequence.shape}, but only 1D sequences are supported. "
"Stats will be computed from the flattened sequence. For computing stats "
"for each dimension consider using the function `compute_dim_to_summary_stats`.",
)

return cls(
mean=float(np.mean(sequence)),
std=float(np.std(sequence)),
@@ -32,6 +43,20 @@ def from_sequence(cls, sequence: Sequence[float | int] | np.ndarray) -> "SequenceSummaryStats":
)


def compute_dim_to_summary_stats(
arr: Sequence[Sequence[float]] | np.ndarray,
) -> dict[int, SequenceSummaryStats]:
"""Compute summary statistics for each dimension of a sequence.
:param arr: a 2-dim arr (or sequence of sequences) from which to compute the statistics.
:return: A dictionary of summary statistics for each dimension.
"""
stats = {}
for dim, seq in enumerate(arr):
stats[dim] = SequenceSummaryStats.from_sequence(seq)
return stats


@dataclass(kw_only=True)
class TimingStats(DataclassPPrintMixin):
"""A data structure for storing timing statistics."""
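
Two usage notes fall out of the stats.py changes: `compute_dim_to_summary_stats` expects dim-major input (row `i` holds every value collected for dimension `i`), and `SequenceSummaryStats.from_sequence` now logs a warning when given a multi-dimensional array before computing the stats over the flattened values. A small sketch with made-up numbers:

```python
import logging

import numpy as np

from tianshou.data.stats import SequenceSummaryStats, compute_dim_to_summary_stats

logging.basicConfig(level=logging.WARNING)

# Dim-major data: row i holds every collected value for dimension i.
arr = np.array([[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]])

per_dim = compute_dim_to_summary_stats(arr)
print(per_dim[0].mean, per_dim[1].mean)  # 2.0 and 20.0

# Passing the 2-D array straight to from_sequence triggers the new warning
# and then computes the stats over the flattened values.
flat = SequenceSummaryStats.from_sequence(arr)
print(flat.mean)  # 11.0
```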
