diff --git a/README.md b/README.md
index 0238d9ecd..025bcd3c4 100644
--- a/README.md
+++ b/README.md
@@ -7,12 +7,6 @@
 [![PyPI](https://img.shields.io/pypi/v/tianshou)](https://pypi.org/project/tianshou/) [![Conda](https://img.shields.io/conda/vn/conda-forge/tianshou)](https://github.com/conda-forge/tianshou-feedstock) [![Read the Docs](https://img.shields.io/readthedocs/tianshou)](https://tianshou.readthedocs.io/en/master) [![Read the Docs](https://img.shields.io/readthedocs/tianshou-docs-zh-cn?label=%E4%B8%AD%E6%96%87%E6%96%87%E6%A1%A3)](https://tianshou.readthedocs.io/zh/master/)
 [![Unittest](https://github.com/thu-ml/tianshou/actions/workflows/pytest.yml/badge.svg)](https://github.com/thu-ml/tianshou/actions) [![codecov](https://img.shields.io/codecov/c/gh/thu-ml/tianshou)](https://codecov.io/gh/thu-ml/tianshou) [![GitHub issues](https://img.shields.io/github/issues/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/issues)
 [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) [![GitHub forks](https://img.shields.io/github/forks/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/network) [![GitHub license](https://img.shields.io/github/license/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/blob/master/LICENSE)
-> ⚠️️ **Current Status**: the Tianshou master branch is currently under heavy development,
-> moving towards more features, improved interfaces, more documentation.
-You can view the relevant issues in the corresponding
-> [milestone](https://github.com/thu-ml/tianshou/milestone/1)
-> Stay tuned! (and expect breaking changes until the next major release)
-
 **Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning platform based on pure PyTorch and [Gymnasium](http://github.com/Farama-Foundation/Gymnasium). Unlike other reinforcement learning libraries, which may have complex codebases, unfriendly high-level APIs, or are not optimized for speed, Tianshou provides a high-performance, modularized framework and user-friendly interfaces for building deep reinforcement learning agents.

 One more aspect that sets Tianshou apart is its
@@ -41,7 +35,7 @@ Supported algorithms include:
 - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
 - [Randomized Ensembled Double Q-Learning (REDQ)](https://arxiv.org/pdf/2101.05982.pdf)
 - [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
-- Vanilla Imitation Learning
+- [Vanilla Imitation Learning](https://en.wikipedia.org/wiki/Apprenticeship_learning)
 - [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
 - [Conservative Q-Learning (CQL)](https://arxiv.org/pdf/2006.04779.pdf)
 - [Twin Delayed DDPG with Behavior Cloning (TD3+BC)](https://arxiv.org/pdf/2106.06860.pdf)
@@ -241,8 +235,9 @@ from tianshou.highlevel.env import (
 from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
 from tianshou.highlevel.params.policy_params import DQNParams
 from tianshou.highlevel.trainer import (
-    TrainerEpochCallbackTestDQNSetEps,
-    TrainerEpochCallbackTrainDQNSetEps,
+    EpochTestCallbackDQNSetEps,
+    EpochTrainCallbackDQNSetEps,
+    EpochStopCallbackRewardThreshold
 )
 ```

diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py
index c6fe6dd04..fc04a219a 100644
--- a/examples/atari/atari_c51.py
+++ b/examples/atari/atari_c51.py
@@ -190,6 +190,7 @@ def watch() -> None:
         sys.exit(0)

     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # trainer
     result = OffpolicyTrainer(
diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py
index 765463cd3..f237c5a33 100644
--- a/examples/atari/atari_dqn.py
+++ b/examples/atari/atari_dqn.py
@@ -83,7 +83,7 @@ def get_args() -> argparse.Namespace:
     return parser.parse_args()


-def test_dqn(args: argparse.Namespace = get_args()) -> None:
+def main(args: argparse.Namespace = get_args()) -> None:
     env, train_envs, test_envs = make_atari_env(
         args.task,
         args.seed,
@@ -232,6 +232,7 @@ def watch() -> None:
         sys.exit(0)

     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # trainer
     result = OffpolicyTrainer(
@@ -259,4 +260,4 @@ def watch() -> None:


 if __name__ == "__main__":
-    test_dqn(get_args())
+    main(get_args())
diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py
index f616a6838..127a14b24 100644
--- a/examples/atari/atari_fqf.py
+++ b/examples/atari/atari_fqf.py
@@ -203,6 +203,7 @@ def watch() -> None:
         sys.exit(0)

     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # trainer
     result = OffpolicyTrainer(
diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py
index 911069400..8b1625275 100644
--- a/examples/atari/atari_iqn.py
+++ b/examples/atari/atari_iqn.py
@@ -200,6 +200,7 @@ def watch() -> None:
         sys.exit(0)

     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)

     # trainer
diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py
index 612b54008..f1a89ef40 100644
--- a/examples/atari/atari_ppo.py
+++ b/examples/atari/atari_ppo.py
@@ -256,6 +256,7 @@ def watch() -> None:
         sys.exit(0)

     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # trainer
     result = OnpolicyTrainer(
diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py
index 7d6330ee1..dfb96419e 100644
--- a/examples/atari/atari_qrdqn.py
+++ b/examples/atari/atari_qrdqn.py
@@ -194,6 +194,7 @@ def watch() -> None:
         sys.exit(0)

     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # trainer
     result = OffpolicyTrainer(
diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py
index 86e7fe0e1..7b341c0a1 100644
--- a/examples/atari/atari_rainbow.py
+++ b/examples/atari/atari_rainbow.py
@@ -230,6 +230,7 @@ def watch() -> None:
         sys.exit(0)

     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # trainer
     result = OffpolicyTrainer(
diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py
index d5edf1a9a..f06964c28 100644
--- a/examples/atari/atari_sac.py
+++ b/examples/atari/atari_sac.py
@@ -243,6 +243,7 @@ def watch() -> None:
         sys.exit(0)

     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # trainer
     result = OffpolicyTrainer(
diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py
index db1b6b2c3..a321fb303 100644
--- a/examples/atari/atari_wrapper.py
+++ b/examples/atari/atari_wrapper.py
@@ -39,6 +39,20 @@ def _parse_reset_result(reset_result: tuple) -> tuple[tuple, dict, bool]:
     return reset_result, {}, contains_info


+def get_space_dtype(obs_space: gym.spaces.Box) -> type[np.floating] | type[np.integer]:
+    obs_space_dtype: type[np.integer] | type[np.floating]
+    if np.issubdtype(obs_space.dtype, np.integer):
+        obs_space_dtype = np.integer
+    elif np.issubdtype(obs_space.dtype, np.floating):
+        obs_space_dtype = np.floating
+    else:
+        raise TypeError(
+            f"Unsupported observation space dtype: {obs_space.dtype}. "
+            f"This might be a bug in tianshou or gymnasium, please report it!",
+        )
+    return obs_space_dtype
+
+
 class NoopResetEnv(gym.Wrapper):
     """Sample initial states by taking random number of no-ops on reset.

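The new `get_space_dtype` helper factors out the dtype-detection logic that was previously duplicated in `WarpFrame` and `FrameStack` (see the two hunks that follow). A minimal usage sketch, not part of the patch: it assumes the helper is importable from `examples/atari/atari_wrapper.py`, and `CartPole-v1` is only a stand-in for any environment with a `Box` observation space.

```python
import gymnasium as gym
import numpy as np

# Assumption: run next to examples/atari/atari_wrapper.py, where the helper is defined.
from atari_wrapper import get_space_dtype

env = gym.make("CartPole-v1")  # placeholder env with a Box observation space
obs_space = env.observation_space
assert isinstance(obs_space, gym.spaces.Box)

# float32 observations are a subtype of np.floating; uint8 Atari frames would map to np.integer.
obs_space_dtype = get_space_dtype(obs_space)

# The WarpFrame/FrameStack hunks below rebuild their observation spaces with this dtype category.
new_space = gym.spaces.Box(
    low=np.min(obs_space.low),
    high=np.max(obs_space.high),
    shape=obs_space.shape,
    dtype=obs_space_dtype,
)
```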
@@ -199,12 +213,8 @@ def __init__(self, env: gym.Env) -> None:
         super().__init__(env)
         self.size = 84
         obs_space = env.observation_space
-        obs_space_dtype: type[np.floating[Any]] | type[np.integer[Any]]
-        if np.issubdtype(type(obs_space.dtype), np.integer):
-            obs_space_dtype = np.integer
-        elif np.issubdtype(type(obs_space.dtype), np.floating):
-            obs_space_dtype = np.floating
         assert isinstance(obs_space, gym.spaces.Box)
+        obs_space_dtype = get_space_dtype(obs_space)
         self.observation_space = gym.spaces.Box(
             low=np.min(obs_space.low),
             high=np.max(obs_space.high),
@@ -273,15 +283,11 @@ def __init__(self, env: gym.Env, n_frames: int) -> None:
         obs_space_shape = env.observation_space.shape
         assert obs_space_shape is not None
         shape = (n_frames, *obs_space_shape)
-        assert isinstance(env.observation_space, gym.spaces.Box)
-        obs_space_dtype: type[np.floating[Any]] | type[np.integer[Any]]
-        if np.issubdtype(type(obs_space.dtype), np.integer):
-            obs_space_dtype = np.integer
-        elif np.issubdtype(type(obs_space.dtype), np.floating):
-            obs_space_dtype = np.floating
+        assert isinstance(obs_space, gym.spaces.Box)
+        obs_space_dtype = get_space_dtype(obs_space)
         self.observation_space = gym.spaces.Box(
-            low=np.min(env.observation_space.low),
-            high=np.max(env.observation_space.high),
+            low=np.min(obs_space.low),
+            high=np.max(obs_space.high),
             shape=shape,
             dtype=obs_space_dtype,
         )
diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py
index ad53b16da..96e61b612 100644
--- a/examples/box2d/acrobot_dualdqn.py
+++ b/examples/box2d/acrobot_dualdqn.py
@@ -92,6 +92,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     )
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "dqn")
diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py
index f52f6d5c1..8b1e8ca8d 100644
--- a/examples/box2d/bipedal_bdq.py
+++ b/examples/box2d/bipedal_bdq.py
@@ -117,6 +117,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
     )
     test_collector = Collector(policy, test_envs, exploration_noise=False)
     # policy.set_eps(1)
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py
index 47ba9d102..9e5db5833 100644
--- a/examples/box2d/lunarlander_dqn.py
+++ b/examples/box2d/lunarlander_dqn.py
@@ -94,6 +94,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     )
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "dqn")
diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py
index be6594aa7..cd6ceec89 100644
--- a/examples/mujoco/fetch_her_ddpg.py
+++ b/examples/mujoco/fetch_her_ddpg.py
@@ -213,6 +213,7 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray:
     )
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs)
+    train_collector.reset()
     train_collector.collect(n_step=args.start_timesteps, random=True)

     def save_best_fn(policy: BasePolicy) -> None:
diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py
index ceac47604..b2a40878b 100755
--- a/examples/mujoco/mujoco_ddpg.py
+++ b/examples/mujoco/mujoco_ddpg.py
@@ -122,6 +122,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
     buffer = ReplayBuffer(args.buffer_size)
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs)
+    train_collector.reset()
     train_collector.collect(n_step=args.start_timesteps, random=True)

     # log
diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py
index b300e498a..ae46b220c 100755
--- a/examples/mujoco/mujoco_redq.py
+++ b/examples/mujoco/mujoco_redq.py
@@ -150,6 +150,7 @@ def linear(x: int, y: int) -> EnsembleLinear:
     buffer = ReplayBuffer(args.buffer_size)
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs)
+    train_collector.reset()
     train_collector.collect(n_step=args.start_timesteps, random=True)

     # log
diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py
index 2058a71e9..a0bd567ff 100755
--- a/examples/mujoco/mujoco_sac.py
+++ b/examples/mujoco/mujoco_sac.py
@@ -144,6 +144,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
     buffer = ReplayBuffer(args.buffer_size)
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs)
+    train_collector.reset()
     train_collector.collect(n_step=args.start_timesteps, random=True)

     # log
diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py
index 30e7539c1..6b6dfdc8c 100755
--- a/examples/mujoco/mujoco_td3.py
+++ b/examples/mujoco/mujoco_td3.py
@@ -142,6 +142,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
     buffer = ReplayBuffer(args.buffer_size)
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs)
+    train_collector.reset()
     train_collector.collect(n_step=args.start_timesteps, random=True)

     # log
diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py
index 62daaf64f..4211585af 100644
--- a/examples/vizdoom/vizdoom_c51.py
+++ b/examples/vizdoom/vizdoom_c51.py
@@ -196,6 +196,7 @@ def watch() -> None:
         sys.exit(0)

     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # trainer
     result = OffpolicyTrainer(
diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py
index 7476d4f26..f5abf0b6f 100644
--- a/examples/vizdoom/vizdoom_ppo.py
+++ b/examples/vizdoom/vizdoom_ppo.py
@@ -258,6 +258,7 @@ def watch() -> None:
         sys.exit(0)

     # test train_collector and start filling replay buffer
+    train_collector.reset()
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # trainer
     result = OnpolicyTrainer(
diff --git a/test/base/test_batch.py b/test/base/test_batch.py
index 0ce219e75..b45ea00c5 100644
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -557,7 +557,7 @@ def test_batch_standard_compatibility() -> None:
     batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0]))
     batch_mean = np.mean(batch)
     assert isinstance(batch_mean, Batch)  # type: ignore  # mypy doesn't know but it works, cf. `batch.rst`
-    assert sorted(batch_mean.keys()) == ["a", "b", "c"]  # type: ignore
+    assert sorted(batch_mean.get_keys()) == ["a", "b", "c"]  # type: ignore
     with pytest.raises(TypeError):
         len(batch_mean)
     assert np.all(batch_mean.a == np.mean(batch.a, axis=0))
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index 31265f664..5488ff365 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -1379,11 +1379,14 @@ def test_custom_key() -> None:
     buffer.add(batch)
     sampled_batch, _ = buffer.sample(1)
     # Check if they have the same keys
-    assert set(batch.keys()) == set(
-        sampled_batch.keys(),
-    ), "Batches have different keys: {} and {}".format(set(batch.keys()), set(sampled_batch.keys()))
+    assert set(batch.get_keys()) == set(
+        sampled_batch.get_keys(),
+    ), "Batches have different keys: {} and {}".format(
+        set(batch.get_keys()),
+        set(sampled_batch.get_keys()),
+    )
     # Compare the values for each key
-    for key in batch.keys():
+    for key in batch.get_keys():
         if isinstance(batch.__dict__[key], np.ndarray) and isinstance(
             sampled_batch.__dict__[key],
             np.ndarray,
diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py
index fd5b15a9f..d13b03d85 100644
--- a/test/continuous/test_sac_with_il.py
+++ b/test/continuous/test_sac_with_il.py
@@ -36,7 +36,7 @@ def get_args() -> argparse.Namespace:
     parser.add_argument("--alpha", type=float, default=0.2)
     parser.add_argument("--auto-alpha", type=int, default=1)
     parser.add_argument("--alpha-lr", type=float, default=3e-4)
-    parser.add_argument("--epoch", type=int, default=5)
+    parser.add_argument("--epoch", type=int, default=10)
     parser.add_argument("--step-per-epoch", type=int, default=24000)
     parser.add_argument("--il-step-per-epoch", type=int, default=500)
     parser.add_argument("--step-per-collect", type=int, default=10)
diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py
index c3f6afe3a..1089d4ba0 100644
--- a/test/discrete/test_bdq.py
+++ b/test/discrete/test_bdq.py
@@ -115,7 +115,8 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
     )
     test_collector = Collector(policy, test_envs, exploration_noise=False)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)

     def train_fn(epoch: int, env_step: int) -> None:  # exp decay
         eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test)
diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py
index 483aca9c6..4d25d430b 100644
--- a/test/discrete/test_c51.py
+++ b/test/discrete/test_c51.py
@@ -119,7 +119,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
     train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "c51")
     writer = SummaryWriter(log_path)
diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py
index 6c588839f..b62a93c3f 100644
--- a/test/discrete/test_dqn.py
+++ b/test/discrete/test_dqn.py
@@ -110,7 +110,8 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "dqn")
     writer = SummaryWriter(log_path)
diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py
index 8bca5c131..5c24518bb 100644
--- a/test/discrete/test_drqn.py
+++ b/test/discrete/test_drqn.py
@@ -94,7 +94,8 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
     # the stack_num is for RNN training: sample framestack obs
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "drqn")
     writer = SummaryWriter(log_path)
diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py
index f1af574f7..8ff9eeb7a 100644
--- a/test/discrete/test_fqf.py
+++ b/test/discrete/test_fqf.py
@@ -127,7 +127,8 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
     train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "fqf")
     writer = SummaryWriter(log_path)
diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py
index 87e7398b9..765bbf9bd 100644
--- a/test/discrete/test_iqn.py
+++ b/test/discrete/test_iqn.py
@@ -123,7 +123,8 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
     train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "iqn")
     writer = SummaryWriter(log_path)
diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py
index 879717a75..6485637e8 100644
--- a/test/discrete/test_qrdqn.py
+++ b/test/discrete/test_qrdqn.py
@@ -112,7 +112,8 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
     train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "qrdqn")
     writer = SummaryWriter(log_path)
diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py
index c7035345e..ff4ef1c1e 100644
--- a/test/discrete/test_rainbow.py
+++ b/test/discrete/test_rainbow.py
@@ -127,7 +127,8 @@ def noisy_linear(x: int, y: int) -> NoisyLinear:
     train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "rainbow")
     writer = SummaryWriter(log_path)
diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py
index 9ca9c7055..9a4206e18 100644
--- a/test/modelbased/test_dqn_icm.py
+++ b/test/modelbased/test_dqn_icm.py
@@ -153,7 +153,8 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
     train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "dqn_icm")
     writer = SummaryWriter(log_path)
diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py
index e8411b2dd..93877944e 100644
--- a/test/offline/gather_cartpole_data.py
+++ b/test/offline/gather_cartpole_data.py
@@ -118,7 +118,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     test_collector.reset()
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, args.task, "qrdqn")
     writer = SummaryWriter(log_path)
diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py
index 7b3fb4dfc..abd0c889a 100644
--- a/test/pettingzoo/pistonball.py
+++ b/test/pettingzoo/pistonball.py
@@ -135,7 +135,8 @@ def train_agent(
         exploration_noise=True,
     )
     test_collector = Collector(policy, test_envs, exploration_noise=True)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, "pistonball", "dqn")
     writer = SummaryWriter(log_path)
diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py
index da580f358..7ed631912 100644
--- a/test/pettingzoo/tic_tac_toe.py
+++ b/test/pettingzoo/tic_tac_toe.py
@@ -171,7 +171,8 @@ def train_agent(
     )
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
-    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
+    train_collector.reset()
+    train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, "tic_tac_toe", "dqn")
     writer = SummaryWriter(log_path)
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index 4eee0cd81..95b25f741 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -1,6 +1,6 @@
 import pprint
 import warnings
-from collections.abc import Collection, Iterable, Iterator, Sequence
+from collections.abc import Collection, Iterable, Iterator, KeysView, Sequence
 from copy import deepcopy
 from numbers import Number
 from types import EllipsisType
@@ -186,8 +186,8 @@ def alloc_by_keys_diff(

     This mainly is an internal method, use it only if you know what you are doing.
     """
-    for key in batch.keys():
-        if key in meta.keys():
+    for key in batch.get_keys():
+        if key in meta.get_keys():
             if isinstance(meta[key], Batch) and isinstance(batch[key], Batch):
                 alloc_by_keys_diff(meta[key], batch[key], size, stack)
             elif isinstance(meta[key], Batch) and meta[key].is_empty():
@@ -450,6 +450,9 @@ def to_dict(self, recurse: bool = False) -> dict[str, Any]:
                 result[k] = v
         return result

+    def get_keys(self) -> KeysView:
+        return self.__dict__.keys()
+
     def to_list_of_dicts(self) -> list[dict[str, Any]]:
         return [entry.to_dict() for entry in self]
diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py
index 53f9bd8eb..a34964f5a 100644
--- a/tianshou/data/buffer/base.py
+++ b/tianshou/data/buffer/base.py
@@ -253,12 +253,12 @@ def add(
         """
         # preprocess batch
         new_batch = Batch()
-        for key in batch.keys():
+        for key in batch.get_keys():
             new_batch.__dict__[key] = batch[key]
         batch = new_batch
         batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
         assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(
-            batch.keys(),
+            batch.get_keys(),
         )  # important to do after preprocess batch
         stacked_batch = buffer_ids is not None
         if stacked_batch:
diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py
index a495b0ada..90480257a 100644
--- a/tianshou/data/buffer/manager.py
+++ b/tianshou/data/buffer/manager.py
@@ -127,11 +127,11 @@ def add(
         """
         # preprocess batch
         new_batch = Batch()
-        for key in set(self._reserved_keys).intersection(batch.keys()):
+        for key in set(self._reserved_keys).intersection(batch.get_keys()):
             new_batch.__dict__[key] = batch[key]
         batch = new_batch
         batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
-        assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(batch.keys())
+        assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(batch.get_keys())
         if self._save_only_last_obs:
             batch.obs = batch.obs[:, -1]
         if not self._save_obs_next:
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 751fedfb2..345d50b03 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -330,10 +330,7 @@ def collect(
         :param random: whether to use random policy for collecting data.
         :param render: the sleep time between rendering consecutive frames.
        :param no_grad: whether to retain gradient in policy.forward().
-        :param reset_before_collect: whether to reset the environment before
-            collecting data.
-            It has only an effect if n_episode is not None, i.e.
-            if one wants to collect a fixed number of episodes.
+        :param reset_before_collect: whether to reset the environment before collecting data.
             (The collector needs the initial obs and info to function properly.)
         :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
             reset function. Only used if reset_before_collect is True.
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 17f0550fc..87cbd4927 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -311,7 +311,8 @@ def _watch_agent(
     ) -> None:
         policy.eval()
         collector = Collector(policy, env)
-        result = collector.collect(n_episode=num_episodes, render=render, reset_before_collect=True)
+        collector.reset()
+        result = collector.collect(n_episode=num_episodes, render=render)
         assert result.returns_stat is not None  # for mypy
         assert result.lens_stat is not None  # for mypy
         log.info(
diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py
index 7a7be5888..81cfe0a6d 100644
--- a/tianshou/policy/multiagent/mapolicy.py
+++ b/tianshou/policy/multiagent/mapolicy.py
@@ -225,7 +225,7 @@ def forward(  # type: ignore
                 results.append((False, np.array([-1]), Batch(), Batch(), Batch()))
                 continue
             tmp_batch = batch[agent_index]
-            if "rew" in tmp_batch.keys() and isinstance(tmp_batch.rew, np.ndarray):
+            if "rew" in tmp_batch.get_keys() and isinstance(tmp_batch.rew, np.ndarray):
                 # reward can be empty Batch (after initial reset) or nparray.
                 tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]]
             if not hasattr(tmp_batch.obs, "mask"):
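Several call sites above replace `batch.keys()` with the new `Batch.get_keys()` accessor introduced in `tianshou/data/batch.py`. A small sketch of how it behaves, mirroring the updated assertions in `test_batch.py` and `test_buffer.py`; the field names are just examples.

```python
import numpy as np

from tianshou.data import Batch

batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0]))

# get_keys() is a thin wrapper around the underlying __dict__ keys view.
assert sorted(batch.get_keys()) == ["a", "b", "c"]

# Membership checks and set operations work just as they did with keys().
assert "rew" not in batch.get_keys()
assert set(batch.get_keys()).issuperset({"a", "c"})
```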