Skip to content

Commit

Permalink
Merge branch 'master' into CI/add-publish-workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
MischaPanch authored Mar 4, 2024
2 parents eb99da3 + fdb69f1 commit 3ae2215
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 26 deletions.
39 changes: 24 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

[![PyPI](https://img.shields.io/pypi/v/tianshou)](https://pypi.org/project/tianshou/) [![Conda](https://img.shields.io/conda/vn/conda-forge/tianshou)](https://github.com/conda-forge/tianshou-feedstock) [![Read the Docs](https://img.shields.io/readthedocs/tianshou)](https://tianshou.readthedocs.io/en/master) [![Read the Docs](https://img.shields.io/readthedocs/tianshou-docs-zh-cn?label=%E4%B8%AD%E6%96%87%E6%96%87%E6%A1%A3)](https://tianshou.readthedocs.io/zh/master/) [![Unittest](https://github.com/thu-ml/tianshou/actions/workflows/pytest.yml/badge.svg)](https://github.com/thu-ml/tianshou/actions) [![codecov](https://img.shields.io/codecov/c/gh/thu-ml/tianshou)](https://codecov.io/gh/thu-ml/tianshou) [![GitHub issues](https://img.shields.io/github/issues/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/issues) [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) [![GitHub forks](https://img.shields.io/github/forks/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/network) [![GitHub license](https://img.shields.io/github/license/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/blob/master/LICENSE)


> ⚠️️ **Current Status**: the Tianshou master branch is currently under heavy development,
> moving towards more features, improved interfaces, more documentation.
You can view the relevant issues in the corresponding
Expand Down Expand Up @@ -178,7 +179,7 @@ Find example scripts in the [test/](https://github.com/thu-ml/tianshou/blob/mast

<sup>(4): super fast APPO!</sup>

### High quality software engineering standard
### High Software Engineering Standards

| RL Platform | Documentation | Code Coverage | Type Hints | Last Update |
| ------------------------------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------ | ----------------------------------------------------------------------------------------------------------------- |
Expand Down Expand Up @@ -232,8 +233,6 @@ We shall apply the deep Q network (DQN) learning algorithm using both APIs.

### High-Level API

The high-level API requires the extra package `argparse` (by adding
`--extras argparse`) to be installed.
To get started, we need some imports.

```python
Expand Down Expand Up @@ -332,11 +331,15 @@ Here's a run (with the training time cut short):
<img src="docs/_static/images/discrete_dqn_hl.gif">
</p>

Find many further applications of the high-level API in the `examples/` folder;
look for scripts ending with `_hl.py`.
Note that most of these examples require the extra package `argparse`
(install it by adding `--extras argparse` when invoking poetry).

### Procedural API

Let us now consider an analogous example in the procedural API.
Find the full script from which the snippets below were derived at [test/discrete/test_dqn.py](https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_dqn.py).
Find the full script in [examples/discrete/discrete_dqn.py](https://github.com/thu-ml/tianshou/blob/master/examples/discrete/discrete_dqn.py).

First, import some relevant packages:

Expand All @@ -357,32 +360,38 @@ gamma, n_step, target_freq = 0.9, 3, 320
buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, step_per_collect = 10000, 10
logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn')) # TensorBoard is supported!
# For other loggers: https://tianshou.readthedocs.io/en/master/01_tutorials/05_logger.html
```

Initialize the logger:

```python
logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn'))
# For other loggers, see https://tianshou.readthedocs.io/en/master/01_tutorials/05_logger.html
```

Make environments:

```python
# you can also try with SubprocVectorEnv
# You can also try SubprocVectorEnv, which will use parallelization
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
```

Define the network:
Create the network as well as its optimizer:

```python
from tianshou.utils.net.common import Net
# you can define other net by following the API:
# https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network

# Note: You can easily define other networks.
# See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network
env = gym.make(task, render_mode="human")
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
optim = torch.optim.Adam(net.parameters(), lr=lr)
```

Setup policy and collectors:
Set up the policy and collectors:

```python
policy = ts.policy.DQNPolicy(
Expand Down Expand Up @@ -418,14 +427,14 @@ result = ts.trainer.OffpolicyTrainer(
print(f"Finished training in {result.timing.total_time} seconds")
```

Save / load the trained policy (it's exactly the same as PyTorch `nn.module`):
Save/load the trained policy (it's exactly the same as loading a `torch.nn.module`):

```python
torch.save(policy.state_dict(), 'dqn.pth')
policy.load_state_dict(torch.load('dqn.pth'))
```

Watch the performance with 35 FPS:
Watch the agent with 35 FPS:

```python
policy.eval()
Expand All @@ -434,13 +443,13 @@ collector = ts.data.Collector(policy, env, exploration_noise=True)
collector.collect(n_episode=1, render=1 / 35)
```

Look at the result saved in tensorboard: (with bash script in your terminal)
Inspect the data saved in TensorBoard:

```bash
$ tensorboard --logdir log/dqn
```

You can check out the [documentation](https://tianshou.readthedocs.io) for advanced usage.
Please read the [documentation](https://tianshou.readthedocs.io) for advanced usage.

## Contributing

Expand Down
18 changes: 7 additions & 11 deletions examples/discrete/discrete_dqn.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from typing import cast

import gymnasium as gym
import torch
from torch.utils.tensorboard import SummaryWriter

import tianshou as ts
from tianshou.utils.space_info import SpaceInfo


def main() -> None:
Expand All @@ -16,22 +13,21 @@ def main() -> None:
buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, step_per_collect = 10000, 10

logger = ts.utils.TensorboardLogger(SummaryWriter("log/dqn")) # TensorBoard is supported!
# For other loggers: https://tianshou.readthedocs.io/en/master/tutorials/logger.html
# For other loggers, see https://tianshou.readthedocs.io/en/master/tutorials/logger.html

# you can also try with SubprocVectorEnv
# You can also try SubprocVectorEnv, which will use parallelization
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])

from tianshou.utils.net.common import Net

# you can define other net by following the API:
# https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network
# Note: You can easily define other networks.
# See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network
env = gym.make(task, render_mode="human")
env.action_space = cast(gym.spaces.Discrete, env.action_space)
space_info = SpaceInfo.from_env(env)
state_shape = space_info.observation_info.obs_shape
action_shape = space_info.action_info.action_shape
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
optim = torch.optim.Adam(net.parameters(), lr=lr)

Expand Down

0 comments on commit 3ae2215

Please sign in to comment.