Skip to content

Commit

Permalink
HL interfaces: don't create train-test collectors if not training (#1208)
Browse files Browse the repository at this point in the history

Minor fix/extension
  • Loading branch information
MischaPanch authored Aug 14, 2024
1 parent d82d29d commit b8f3156
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
26 changes: 17 additions & 9 deletions tianshou/highlevel/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from sensai.util.logging import datetime_tag
from sensai.util.string import ToStringMixin

from tianshou.data import Collector, InfoStats
from tianshou.data import BaseCollector, Collector, InfoStats
from tianshou.env import BaseVectorEnv
from tianshou.highlevel.agent import (
A2CAgentFactory,
Expand Down Expand Up @@ -110,13 +110,14 @@
from tianshou.policy import BasePolicy
from tianshou.utils import LazyLogger
from tianshou.utils.net.common import ModuleType
from tianshou.utils.print import DataclassPPrintMixin
from tianshou.utils.warning import deprecation

log = logging.getLogger(__name__)


@dataclass
class ExperimentConfig:
class ExperimentConfig(ToStringMixin, DataclassPPrintMixin):
"""Generic config for setting up the experiment, not RL or training specific."""

seed: int = 42
Expand Down Expand Up @@ -160,7 +161,7 @@ class ExperimentResult:
"""dataclass of results as returned by the trainer (if any)"""


class Experiment(ToStringMixin):
class Experiment(ToStringMixin, DataclassPPrintMixin):
"""Represents a reinforcement learning experiment.
An experiment is composed only of configuration and factory objects, which themselves
Expand Down Expand Up @@ -332,12 +333,16 @@ def create_experiment_world(
# create policy and collectors
log.info("Creating policy")
policy = self.agent_factory.create_policy(envs, self.config.device)

log.info("Creating collectors")
train_collector, test_collector = self.agent_factory.create_train_test_collector(
policy,
envs,
reset_collectors=reset_collectors,
)
train_collector: BaseCollector | None = None
test_collector: BaseCollector | None = None
if self.config.train:
train_collector, test_collector = self.agent_factory.create_train_test_collector(
policy,
envs,
reset_collectors=reset_collectors,
)

# create context object with all relevant instances (except trainer; added later)
world = World(
Expand Down Expand Up @@ -413,6 +418,10 @@ def run(
):
trainer_result: InfoStats | None = None
if self.config.train:
assert world.trainer is not None
assert world.train_collector is not None
assert world.test_collector is not None

# prefilling buffers with either random or current agent's actions
if self.sampling_config.start_timesteps > 0:
log.info(
Expand All @@ -425,7 +434,6 @@ def run(
)

log.info("Starting training")
assert world.trainer is not None
world.trainer.run()
if use_persistence:
world.logger.finalize()
Expand Down
6 changes: 3 additions & 3 deletions tianshou/highlevel/world.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from tianshou.trainer import BaseTrainer


@dataclass
@dataclass(kw_only=True)
class World:
"""Container for instances and configuration items that are relevant to an experiment."""

envs: "Environments"
policy: "BasePolicy"
train_collector: "BaseCollector"
test_collector: "BaseCollector"
train_collector: Optional["BaseCollector"] = None
test_collector: Optional["BaseCollector"] = None
logger: "TLogger"
persist_directory: str
restore_directory: str | None
Expand Down

0 comments on commit b8f3156

Please sign in to comment.