From a38e586b0c265f3f5e5ae6c72aa174fb45ecbc22 Mon Sep 17 00:00:00 2001
From: anyongjin
Date: Wed, 14 Aug 2024 16:45:21 +0800
Subject: [PATCH] Support computing custom scores and terminating/saving based
 on them in BaseTrainer (#1202)

This PR introduces a new concept into tianshou training: a `best_score`. It is
computed from the appropriate `Stats` instance and always added to `InfoStats`.

## Breaking Changes:

- `InfoStats` has a new non-optional field `best_score`

## Background

Currently, tianshou uses the maximum average return to select the best model.
This does not always match user needs: for example, a policy whose average
return is only 5% lower but whose standard deviation is 50% lower is generally
considered more stable and therefore preferable. A custom score lets users
express such trade-offs (a usage sketch follows the diff).
---
 tianshou/data/stats.py    |  2 ++
 tianshou/trainer/base.py  | 38 ++++++++++++++++++++++++++++++++------
 tianshou/trainer/utils.py |  2 ++
 3 files changed, 36 insertions(+), 6 deletions(-)

diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py
index 4685f5730..ed64a429d 100644
--- a/tianshou/data/stats.py
+++ b/tianshou/data/stats.py
@@ -56,6 +56,8 @@ class InfoStats(DataclassPPrintMixin):
 
     gradient_step: int
     """The total gradient step."""
+    best_score: float
+    """The best score over the test results. The model with the highest score is considered the best one."""
     best_reward: float
     """The best reward over the test results."""
     best_reward_std: float
diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index 242f2b028..a6679fa20 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -73,6 +73,8 @@ class BaseTrainer(ABC):
     :param test_fn: a hook called at the beginning of testing in each epoch.
         It can be used to perform custom additional operations, with the
         signature ``f(num_epoch: int, step_idx: int) -> None``.
+    :param compute_score_fn: calculate the test batch performance score, used to determine
+        whether the current model is the best one so far; the mean return is used as the score if not provided.
     :param save_best_fn: a hook called when the undiscounted average mean reward in
         evaluation phase gets better, with the signature
         ``f(policy: BasePolicy) -> None``.
@@ -164,6 +166,7 @@ def __init__(
         train_fn: Callable[[int, int], None] | None = None,
         test_fn: Callable[[int, int | None], None] | None = None,
         stop_fn: Callable[[float], bool] | None = None,
+        compute_score_fn: Callable[[CollectStats], float] | None = None,
         save_best_fn: Callable[[BasePolicy], None] | None = None,
         save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
         resume_from_log: bool = False,
@@ -185,6 +188,7 @@ def __init__(
         self.logger = logger
         self.start_time = time.time()
         self.stat: defaultdict[str, MovAvg] = defaultdict(MovAvg)
+        self.best_score = 0.0
         self.best_reward = 0.0
         self.best_reward_std = 0.0
         self.start_epoch = 0
@@ -210,6 +214,14 @@ def __init__(
         self.train_fn = train_fn
         self.test_fn = test_fn
         self.stop_fn = stop_fn
+        self.compute_score_fn: Callable[[CollectStats], float]
+        if compute_score_fn is None:
+
+            def compute_score_fn(stat: CollectStats) -> float:
+                assert stat.returns_stat is not None  # for mypy
+                return stat.returns_stat.mean
+
+        self.compute_score_fn = compute_score_fn
         self.save_best_fn = save_best_fn
         self.save_checkpoint_fn = save_checkpoint_fn
 
@@ -273,6 +285,7 @@ def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> No
                 test_result.returns_stat.mean,
                 test_result.returns_stat.std,
             )
+            self.best_score = self.compute_score_fn(test_result)
             if self.save_best_fn:
                 self.save_best_fn(self.policy)
 
@@ -351,6 +364,7 @@ def __next__(self) -> EpochStats:
             start_time=self.start_time,
             policy_update_time=self.policy_update_time,
             gradient_step=self._gradient_step,
+            best_score=self.best_score,
             best_reward=self.best_reward,
             best_reward_std=self.best_reward_std,
             train_collector=self.train_collector,
@@ -384,17 +398,27 @@ def test_step(self) -> tuple[CollectStats, bool]:
         )
         assert test_stat.returns_stat is not None  # for mypy
         rew, rew_std = test_stat.returns_stat.mean, test_stat.returns_stat.std
-        if self.best_epoch < 0 or self.best_reward < rew:
+        score = self.compute_score_fn(test_stat)
+        if self.best_epoch < 0 or self.best_score < score:
+            self.best_score = score
             self.best_epoch = self.epoch
             self.best_reward = float(rew)
             self.best_reward_std = rew_std
             if self.save_best_fn:
                 self.save_best_fn(self.policy)
-        log_msg = (
-            f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
-            f" best_reward: {self.best_reward:.6f} ± "
-            f"{self.best_reward_std:.6f} in #{self.best_epoch}"
-        )
+        if score != rew:
+            # use custom score calculator
+            log_msg = (
+                f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, score: {score:.6f},"
+                f" best_reward: {self.best_reward:.6f} ± "
+                f"{self.best_reward_std:.6f}, score: {self.best_score:.6f} in #{self.best_epoch}"
+            )
+        else:
+            log_msg = (
+                f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
+                f" best_reward: {self.best_reward:.6f} ± "
+                f"{self.best_reward_std:.6f} in #{self.best_epoch}"
+            )
         log.info(log_msg)
         if self.verbose:
             print(log_msg, flush=True)
@@ -506,6 +530,7 @@ def _update_best_reward_and_return_should_stop_training(
                 should_stop_training = True
                 self.best_reward = test_result.returns_stat.mean
                 self.best_reward_std = test_result.returns_stat.std
+                self.best_score = self.compute_score_fn(test_result)
 
         return should_stop_training
 
@@ -562,6 +587,7 @@ def run(self, reset_prior_to_run: bool = True) -> InfoStats:
             start_time=self.start_time,
             policy_update_time=self.policy_update_time,
             gradient_step=self._gradient_step,
+            best_score=self.best_score,
             best_reward=self.best_reward,
             best_reward_std=self.best_reward_std,
             train_collector=self.train_collector,
diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py
index de730cee2..1f4369f72 100644
--- a/tianshou/trainer/utils.py
+++ b/tianshou/trainer/utils.py
@@ -42,6 +42,7 @@ def gather_info(
     start_time: float,
     policy_update_time: float,
     gradient_step: int,
+    best_score: float,
     best_reward: float,
     best_reward_std: float,
     train_collector: BaseCollector | None = None,
@@ -75,6 +76,7 @@ def gather_info(
 
     return InfoStats(
         gradient_step=gradient_step,
+        best_score=best_score,
         best_reward=best_reward,
         best_reward_std=best_reward_std,
         train_step=train_collector.collect_step if train_collector is not None else 0,
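
Not part of the patch, but for reviewers: a minimal sketch of how a custom score could be plugged into a `BaseTrainer` subclass once this change lands. The helper names (`stability_score`, `train_with_custom_score`), the `0.5` weight on the standard deviation, and the trainer hyperparameters are illustrative assumptions only; the policy and collectors are assumed to be built as in the standard tianshou examples.

```python
from tianshou.data import Collector, CollectStats
from tianshou.policy import BasePolicy
from tianshou.trainer import OffpolicyTrainer


def stability_score(stat: CollectStats) -> float:
    """Score a test run by its mean return, penalizing a high std across episodes."""
    assert stat.returns_stat is not None  # for mypy
    # The 0.5 weight is an arbitrary illustration of a mean/variance trade-off.
    return stat.returns_stat.mean - 0.5 * stat.returns_stat.std


def train_with_custom_score(
    policy: BasePolicy,
    train_collector: Collector,
    test_collector: Collector,
):
    """Select and save the best model by `stability_score` instead of the mean return."""
    trainer = OffpolicyTrainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=10,
        step_per_epoch=10_000,
        step_per_collect=10,
        episode_per_test=10,
        batch_size=64,
        compute_score_fn=stability_score,  # new argument introduced by this patch
    )
    info = trainer.run()
    # InfoStats now always carries best_score alongside best_reward.
    print(info.best_score, info.best_reward, info.best_reward_std)
    return info
```

With the default (no `compute_score_fn` supplied), behaviour is unchanged: the mean test return is used as the score.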