Skip to content

Commit

Permalink
Allow to configure number of test episodes in high-level API
Browse files Browse the repository at this point in the history
  • Loading branch information
opcode81 committed Feb 14, 2024
1 parent 8742e36 commit bf39185
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
4 changes: 2 additions & 2 deletions tianshou/highlevel/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def create_trainer(
max_epoch=sampling_config.num_epochs,
step_per_epoch=sampling_config.step_per_epoch,
repeat_per_collect=sampling_config.repeat_per_collect,
episode_per_test=sampling_config.num_test_envs,
episode_per_test=sampling_config.num_test_episodes_per_test_env,
batch_size=sampling_config.batch_size,
step_per_collect=sampling_config.step_per_collect,
save_best_fn=policy_persistence.get_save_best_fn(world),
Expand Down Expand Up @@ -228,7 +228,7 @@ def create_trainer(
max_epoch=sampling_config.num_epochs,
step_per_epoch=sampling_config.step_per_epoch,
step_per_collect=sampling_config.step_per_collect,
episode_per_test=sampling_config.num_test_envs,
episode_per_test=sampling_config.num_test_episodes_per_test_env,
batch_size=sampling_config.batch_size,
save_best_fn=policy_persistence.get_save_best_fn(world),
logger=world.logger,
Expand Down
17 changes: 16 additions & 1 deletion tianshou/highlevel/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import multiprocessing
from dataclasses import dataclass

Expand All @@ -16,7 +17,10 @@ class SamplingConfig(ToStringMixin):
* collects environment steps/transitions (collection step), adding them to the (replay)
buffer (see :attr:`step_per_collect`)
* performs one or more gradient updates (see :attr:`update_per_step`).
* performs one or more gradient updates (see :attr:`update_per_step`),
and the test step collects :attr:`num_test_episodes` test episodes in order to evaluate
agent performance.
The number of training steps in each epoch is indirectly determined by
:attr:`step_per_epoch`: As many training steps will be performed as are required in
Expand Down Expand Up @@ -49,6 +53,12 @@ class SamplingConfig(ToStringMixin):
num_test_envs: int = 1
"""the number of test environments to use"""

num_test_episodes: int = 1
"""the total number of episodes to collect in each test step (across all test environments).
This should be a multiple of the number of test environments; if it is not, the effective
number of episodes collected will be rounded up to the next multiple.
"""

buffer_size: int = 4096
"""the total size of the sample/replay buffer, in which environment steps (transitions) are
stored"""
Expand Down Expand Up @@ -119,3 +129,8 @@ class SamplingConfig(ToStringMixin):
def __post_init__(self) -> None:
    """Resolve the ``-1`` placeholder for :attr:`num_train_envs`.

    A value of ``-1`` is replaced by the CPU count reported by
    :func:`multiprocessing.cpu_count`; any other value is left untouched.
    """
    auto_detect_requested = self.num_train_envs == -1
    if not auto_detect_requested:
        return
    self.num_train_envs = multiprocessing.cpu_count()

@property
def num_test_episodes_per_test_env(self) -> int:
    """Compute how many episodes each test environment collects in every test step.

    The value is ``num_test_episodes`` divided by ``num_test_envs``, rounded up to the
    next integer, so that the total collected across all test environments is never
    less than ``num_test_episodes``.

    :return: the number of episodes to collect per test environment in every test step
    :raises ZeroDivisionError: if ``num_test_envs`` is 0
    """
    # Ceiling division done purely on integers: ``math.ceil(a / b)`` goes through a
    # float and silently loses precision once the operands exceed 2**53.
    return -(-self.num_test_episodes // self.num_test_envs)

0 comments on commit bf39185

Please sign in to comment.