diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index b72ab5e96..1a1a0bf76 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -184,7 +184,7 @@ def create_trainer( max_epoch=sampling_config.num_epochs, step_per_epoch=sampling_config.step_per_epoch, repeat_per_collect=sampling_config.repeat_per_collect, - episode_per_test=sampling_config.num_test_envs, + episode_per_test=sampling_config.num_test_episodes_per_test_env, batch_size=sampling_config.batch_size, step_per_collect=sampling_config.step_per_collect, save_best_fn=policy_persistence.get_save_best_fn(world), @@ -228,7 +228,7 @@ def create_trainer( max_epoch=sampling_config.num_epochs, step_per_epoch=sampling_config.step_per_epoch, step_per_collect=sampling_config.step_per_collect, - episode_per_test=sampling_config.num_test_envs, + episode_per_test=sampling_config.num_test_episodes_per_test_env, batch_size=sampling_config.batch_size, save_best_fn=policy_persistence.get_save_best_fn(world), logger=world.logger, diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index 80e04769d..498247214 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -1,3 +1,4 @@ +import math import multiprocessing from dataclasses import dataclass @@ -16,7 +17,10 @@ class SamplingConfig(ToStringMixin): * collects environment steps/transitions (collection step), adding them to the (replay) buffer (see :attr:`step_per_collect`) - * performs one or more gradient updates (see :attr:`update_per_step`). + * performs one or more gradient updates (see :attr:`update_per_step`), + + and the test step collects :attr:`num_test_episodes` test episodes in order to evaluate + agent performance. 
The number of training steps in each epoch is indirectly determined by :attr:`step_per_epoch`: As many training steps will be performed as are required in @@ -49,6 +53,12 @@ class SamplingConfig(ToStringMixin): num_test_envs: int = 1 """the number of test environments to use""" + num_test_episodes: int = 1 + """the total number of episodes to collect in each test step (across all test environments). + This should be a multiple of the number of test environments; if it is not, the effective + number of episodes collected will be the nearest multiple (rounded up). + """ + buffer_size: int = 4096 """the total size of the sample/replay buffer, in which environment steps (transitions) are stored""" @@ -119,3 +129,8 @@ class SamplingConfig(ToStringMixin): def __post_init__(self) -> None: if self.num_train_envs == -1: self.num_train_envs = multiprocessing.cpu_count() + + @property + def num_test_episodes_per_test_env(self) -> int: + """:return: the number of episodes to collect per test environment in every test step""" + return math.ceil(self.num_test_episodes / self.num_test_envs)