Merge branch 'master' into feature/random-agent
MischaPanch authored Aug 18, 2024
2 parents dcf1b2e + 616e6a9 commit 5ae36b4
Showing 6 changed files with 44 additions and 41 deletions.
6 changes: 3 additions & 3 deletions docs/02_notebooks/L3_Vectorized__Environment.ipynb
@@ -197,10 +197,10 @@
"* ShmemVectorEnv: use share memory instead of pipe based on SubprocVectorEnv;\n",
"* RayVectorEnv: use Ray for concurrent activities and is currently the only choice for parallel simulation in a cluster with multiple machines.\n",
"\n",
"Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.env.html) for details.\n",
"Check the [documentation](https://tianshou.org/en/master/03_api/env/venvs.html) for details.\n",
"\n",
"### Difference between synchronous and asynchronous mode (How to choose?)\n",
"Explanation can be found at the [Parallel Sampling](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling) tutorial."
"Explanation can be found at the [Parallel Sampling](https://tianshou.org/en/master/01_tutorials/07_cheatsheet.html#parallel-sampling) tutorial."
]
}
],
@@ -223,7 +223,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.11.7"
}
},
"nbformat": 4,
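As a quick orientation for this hunk, here is a minimal sketch of the vectorized-environment classes the notebook cell refers to. It is a sketch only: it assumes a recent Tianshou release with the Gymnasium-style API (`reset()` returning `(obs, info)` and a five-element tuple from `step()`), and the environment id and worker count are illustrative.

```python
import gymnasium as gym
import numpy as np

from tianshou.env import DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv

# Each entry is a factory; every worker builds its own environment instance.
env_fns = [lambda: gym.make("CartPole-v1") for _ in range(4)]

envs = DummyVectorEnv(env_fns)      # sequential, in-process (debugging, cheap envs)
# envs = SubprocVectorEnv(env_fns)  # one subprocess per env, pipe-based transport
# envs = ShmemVectorEnv(env_fns)    # like SubprocVectorEnv, but shared-memory transport

obs, info = envs.reset()                    # batched: obs stacks the 4 observations
actions = np.random.randint(0, 2, size=4)   # CartPole-v1 has two discrete actions
obs, rew, terminated, truncated, info = envs.step(actions)
envs.close()
```

RayVectorEnv takes the same list of factories but requires `ray` to be installed, which is why the notebook singles it out as the multi-machine option.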
2 changes: 1 addition & 1 deletion docs/02_notebooks/L5_Collector.ipynb
@@ -247,7 +247,7 @@
},
"source": [
"## Further Reading\n",
"The above collector actually collects 52 data at a time because 52 % 4 = 0. There is one asynchronous collector which allows you collect exactly 50 steps. Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.data.html#asynccollector) for details."
"The above collector actually collects 52 data at a time because 52 % 4 = 0. There is one asynchronous collector which allows you collect exactly 50 steps. Check the [documentation](https://tianshou.org/en/master/03_api/data/collector.html#tianshou.data.collector.AsyncCollector) for details."
]
}
],
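To make the 52-versus-50 point in that cell concrete, here is a hedged sketch of the setup it describes: a collector over four parallel environments. `policy` stands for any existing tianshou.policy.BasePolicy instance built elsewhere (not shown), and the buffer size is illustrative.

```python
import gymnasium as gym

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv

# `policy` is assumed to be an existing tianshou.policy.BasePolicy instance.
train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
buffer = VectorReplayBuffer(total_size=2000, buffer_num=4)
collector = Collector(policy, train_envs, buffer)

collector.reset()
# Collection advances all 4 envs in lockstep, so the requested step count is
# rounded up to the next multiple of 4: asking for 50 steps yields 52 transitions.
stats = collector.collect(n_step=50)

# AsyncCollector in tianshou.data is the asynchronous variant mentioned in the
# cell; it lets the envs run out of step and can stop at exactly 50 transitions.
```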
26 changes: 13 additions & 13 deletions docs/02_notebooks/L6_Trainer.ipynb
@@ -54,7 +54,12 @@
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-06T15:34:02.969675Z",
"start_time": "2024-05-06T15:34:00.747309Z"
},
"editable": true,
"id": "do-xZ-8B7nVH",
"slideshow": {
@@ -63,12 +68,9 @@
"tags": [
"hide-cell",
"remove-output"
],
"ExecuteTime": {
"end_time": "2024-05-06T15:34:02.969675Z",
"start_time": "2024-05-06T15:34:00.747309Z"
}
]
},
"outputs": [],
"source": [
"%%capture\n",
"\n",
@@ -82,18 +84,18 @@
"from tianshou.utils.net.common import Net\n",
"from tianshou.utils.net.discrete import Actor\n",
"from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode"
],
"outputs": [],
"execution_count": 1
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-06T15:34:07.536452Z",
"start_time": "2024-05-06T15:34:03.636670Z"
}
},
"outputs": [],
"source": [
"train_env_num = 4\n",
"buffer_size = (\n",
@@ -131,9 +133,7 @@
"replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
"test_collector = Collector(policy, test_envs)\n",
"train_collector = Collector(policy, train_envs, replayBuffer)"
],
"outputs": [],
"execution_count": 2
]
},
{
"cell_type": "markdown",
@@ -252,10 +252,10 @@
"source": [
"## Further Reading\n",
"### Logger usages\n",
"Tianshou provides experiment loggers that are both tensorboard- and wandb-compatible. It also has a BaseLogger Class which allows you to self-define your own logger. Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.utils.html#tianshou.utils.BaseLogger) for details.\n",
"Tianshou provides experiment loggers that are both tensorboard- and wandb-compatible. It also has a BaseLogger Class which allows you to self-define your own logger. Check the [documentation](https://tianshou.org/en/master/03_api/utils/logger/base.html#tianshou.utils.logger.base.BaseLogger) for details.\n",
"\n",
"### Learn more about the APIs of Trainers\n",
"[documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.trainer.html)"
"[documentation](https://tianshou.org/en/master/03_api/trainer/index.html)"
]
}
],
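Since this hunk only swaps documentation links, a brief sketch of the logger usage that the "Further Reading" cell points to may help. The log directory below is a made-up example; the relevant pieces are the `TensorboardLogger` wrapper and the trainer's `logger` argument.

```python
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import TensorboardLogger

writer = SummaryWriter("log/dqn_example")  # illustrative log directory
logger = TensorboardLogger(writer)

# The logger is then passed to a trainer via its `logger` argument; a
# WandbLogger (also in tianshou.utils) can be dropped in the same way, and a
# custom logger can subclass tianshou.utils.logger.base.BaseLogger.
```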
26 changes: 17 additions & 9 deletions tianshou/highlevel/experiment.py
@@ -36,7 +36,7 @@
from sensai.util.logging import datetime_tag
from sensai.util.string import ToStringMixin

from tianshou.data import Collector, InfoStats
from tianshou.data import BaseCollector, Collector, InfoStats
from tianshou.env import BaseVectorEnv
from tianshou.highlevel.agent import (
A2CAgentFactory,
@@ -111,13 +111,14 @@
from tianshou.policy import BasePolicy
from tianshou.utils import LazyLogger
from tianshou.utils.net.common import ModuleType
from tianshou.utils.print import DataclassPPrintMixin
from tianshou.utils.warning import deprecation

log = logging.getLogger(__name__)


@dataclass
class ExperimentConfig:
class ExperimentConfig(ToStringMixin, DataclassPPrintMixin):
"""Generic config for setting up the experiment, not RL or training specific."""

seed: int = 42
@@ -161,7 +162,7 @@ class ExperimentResult:
"""dataclass of results as returned by the trainer (if any)"""


class Experiment(ToStringMixin):
class Experiment(ToStringMixin, DataclassPPrintMixin):
"""Represents a reinforcement learning experiment.
An experiment is composed only of configuration and factory objects, which themselves
@@ -333,12 +334,16 @@ def create_experiment_world(
# create policy and collectors
log.info("Creating policy")
policy = self.agent_factory.create_policy(envs, self.config.device)

log.info("Creating collectors")
train_collector, test_collector = self.agent_factory.create_train_test_collector(
policy,
envs,
reset_collectors=reset_collectors,
)
train_collector: BaseCollector | None = None
test_collector: BaseCollector | None = None
if self.config.train:
train_collector, test_collector = self.agent_factory.create_train_test_collector(
policy,
envs,
reset_collectors=reset_collectors,
)

# create context object with all relevant instances (except trainer; added later)
world = World(
@@ -414,6 +419,10 @@ def run(
):
trainer_result: InfoStats | None = None
if self.config.train:
assert world.trainer is not None
assert world.train_collector is not None
assert world.test_collector is not None

# prefilling buffers with either random or current agent's actions
if self.sampling_config.start_timesteps > 0:
log.info(
@@ -426,7 +435,6 @@
)

log.info("Starting training")
assert world.trainer is not None
world.trainer.run()
if use_persistence:
world.logger.finalize()
6 changes: 3 additions & 3 deletions tianshou/highlevel/world.py
@@ -10,14 +10,14 @@
from tianshou.trainer import BaseTrainer


@dataclass
@dataclass(kw_only=True)
class World:
"""Container for instances and configuration items that are relevant to an experiment."""

envs: "Environments"
policy: "BasePolicy"
train_collector: "BaseCollector"
test_collector: "BaseCollector"
train_collector: Optional["BaseCollector"] = None
test_collector: Optional["BaseCollector"] = None
logger: "TLogger"
persist_directory: str
restore_directory: str | None
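The switch to `@dataclass(kw_only=True)` above is what lets the two collector fields gain `None` defaults while later fields such as `logger` keep no default: with keyword-only fields, Python no longer enforces the "non-default after default" ordering rule, and every `World(...)` call site must name its arguments. Below is a minimal, self-contained illustration of that mechanic, using stand-in field types rather than the real World members.

```python
from dataclasses import dataclass


@dataclass(kw_only=True)
class RunContext:  # stand-in for the pattern, not the actual World class
    policy: str
    train_collector: str | None = None  # defaulted fields may now appear
    test_collector: str | None = None   # before required ones
    logger: str                         # required, no default -> fine with kw_only


ctx = RunContext(policy="dqn", logger="tensorboard")  # keyword-only construction
assert ctx.train_collector is None

# Without kw_only=True, declaring `logger: str` after the defaulted fields would
# raise a TypeError ("non-default argument follows default argument") at
# class-definition time.
```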
19 changes: 7 additions & 12 deletions tianshou/trainer/base.py
@@ -406,19 +406,14 @@ def test_step(self) -> tuple[CollectStats, bool]:
self.best_reward_std = rew_std
if self.save_best_fn:
self.save_best_fn(self.policy)
cur_info, best_info = "", ""
if score != rew:
# use custom score calculater
log_msg = (
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, score: {score:.6f},"
f" best_reward: {self.best_reward:.6f} ± "
f"{self.best_reward_std:.6f}, score: {self.best_score:.6f} in #{self.best_epoch}"
)
else:
log_msg = (
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
f" best_reward: {self.best_reward:.6f} ± "
f"{self.best_reward_std:.6f} in #{self.best_epoch}"
)
cur_info, best_info = f", score: {score: .6f}", f", best_score: {self.best_score:.6f}"
log_msg = (
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},{cur_info}"
f" best_reward: {self.best_reward:.6f} ± "
f"{self.best_reward_std:.6f}{best_info} in #{self.best_epoch}"
)
log.info(log_msg)
if self.verbose:
print(log_msg, flush=True)
