add evaluate_test_fn to BaseTrainer
anyongjin committed Aug 13, 2024
1 parent 4c1f779 commit 47b966a
Showing 3 changed files with 36 additions and 6 deletions.
2 changes: 2 additions & 0 deletions tianshou/data/stats.py
@@ -56,6 +56,8 @@ class InfoStats(DataclassPPrintMixin):

gradient_step: int
"""The total gradient step."""
best_score: float
"""The best score over the test results."""
best_reward: float
"""The best reward over the test results."""
best_reward_std: float
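With this change, the InfoStats returned at the end of a training run reports the best score next to the best reward. A minimal sketch of reading the new field, assuming trainer is an already-configured trainer built on BaseTrainer and no custom evaluate_test_fn was supplied (so the score falls back to the mean test reward):

info = trainer.run()
# Without evaluate_test_fn the score defaults to the mean test reward,
# so both fields carry the same value here.
assert info.best_score == info.best_reward
print(f"best reward: {info.best_reward:.6f} ± {info.best_reward_std:.6f}")
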
38 changes: 32 additions & 6 deletions tianshou/trainer/base.py
@@ -73,6 +73,7 @@ class BaseTrainer(ABC):
:param test_fn: a hook called at the beginning of testing in each
epoch. It can be used to perform custom additional operations, with the
signature ``f(num_epoch: int, step_idx: int) -> None``.
:param evaluate_test_fn: a hook that computes a performance score from the
        test collect statistics, used to decide whether the current model is the
        best one so far, with the signature ``f(collect_stats: CollectStats) -> float``.
        If None, the mean test reward is used as the score.
:param save_best_fn: a hook called when the undiscounted average mean
reward in evaluation phase gets better, with the signature
``f(policy: BasePolicy) -> None``.
@@ -164,6 +165,7 @@ def __init__(
train_fn: Callable[[int, int], None] | None = None,
test_fn: Callable[[int, int | None], None] | None = None,
stop_fn: Callable[[float], bool] | None = None,
evaluate_test_fn: Callable[[CollectStats], float] | None = None,
save_best_fn: Callable[[BasePolicy], None] | None = None,
save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
resume_from_log: bool = False,
@@ -185,6 +187,7 @@ def __init__(
self.logger = logger
self.start_time = time.time()
self.stat: defaultdict[str, MovAvg] = defaultdict(MovAvg)
self.best_score = 0.0
self.best_reward = 0.0
self.best_reward_std = 0.0
self.start_epoch = 0
@@ -210,6 +213,7 @@ def __init__(
self.train_fn = train_fn
self.test_fn = test_fn
self.stop_fn = stop_fn
self.evaluate_test_fn = evaluate_test_fn
self.save_best_fn = save_best_fn
self.save_checkpoint_fn = save_checkpoint_fn

@@ -273,6 +277,10 @@ def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> None:
test_result.returns_stat.mean,
test_result.returns_stat.std,
)
if self.evaluate_test_fn:
self.best_score = self.evaluate_test_fn(test_result)
else:
self.best_score = self.best_reward
if self.save_best_fn:
self.save_best_fn(self.policy)

@@ -351,6 +359,7 @@ def __next__(self) -> EpochStats:
start_time=self.start_time,
policy_update_time=self.policy_update_time,
gradient_step=self._gradient_step,
best_score=self.best_score,
best_reward=self.best_reward,
best_reward_std=self.best_reward_std,
train_collector=self.train_collector,
@@ -384,17 +393,29 @@ def test_step(self) -> tuple[CollectStats, bool]:
)
assert test_stat.returns_stat is not None # for mypy
rew, rew_std = test_stat.returns_stat.mean, test_stat.returns_stat.std
if self.best_epoch < 0 or self.best_reward < rew:
if self.evaluate_test_fn:
score = self.evaluate_test_fn(test_stat)
else:
score = float(rew)
if self.best_epoch < 0 or self.best_score < score:
self.best_score = score
self.best_epoch = self.epoch
self.best_reward = float(rew)
self.best_reward_std = rew_std
if self.save_best_fn:
self.save_best_fn(self.policy)
log_msg = (
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
f" best_reward: {self.best_reward:.6f} ± "
f"{self.best_reward_std:.6f} in #{self.best_epoch}"
)
if self.evaluate_test_fn:
log_msg = (
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, score: {score:.6f},"
f" best_reward: {self.best_reward:.6f} ± "
f"{self.best_reward_std:.6f}, score: {self.best_score:.6f} in #{self.best_epoch}"
)
else:
log_msg = (
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
f" best_reward: {self.best_reward:.6f} ± "
f"{self.best_reward_std:.6f} in #{self.best_epoch}"
)
log.info(log_msg)
if self.verbose:
print(log_msg, flush=True)
@@ -506,6 +527,10 @@ def _update_best_reward_and_return_should_stop_training(
should_stop_training = True
self.best_reward = test_result.returns_stat.mean
self.best_reward_std = test_result.returns_stat.std
if self.evaluate_test_fn:
self.best_score = self.evaluate_test_fn(test_result)
else:
self.best_score = self.best_reward

return should_stop_training

@@ -562,6 +587,7 @@ def run(self, reset_prior_to_run: bool = True) -> InfoStats:
start_time=self.start_time,
policy_update_time=self.policy_update_time,
gradient_step=self._gradient_step,
best_score=self.best_score,
best_reward=self.best_reward,
best_reward_std=self.best_reward_std,
train_collector=self.train_collector,
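The new hook receives the test-phase CollectStats and returns a float score; the trainer then tracks best_score, best_reward, and the save_best_fn call against that score rather than the raw mean reward. Below is a minimal sketch of wiring it up. It assumes an OffpolicyTrainer (a BaseTrainer subclass) whose policy and collectors are configured elsewhere, and the remaining constructor arguments are only indicative; the score function itself (mean return minus half its standard deviation) is an illustrative choice, not part of this commit.

from tianshou.data import CollectStats
from tianshou.trainer import OffpolicyTrainer

def risk_adjusted_score(stats: CollectStats) -> float:
    # Mirror the mypy guard used in test_step before touching returns_stat.
    assert stats.returns_stat is not None
    return float(stats.returns_stat.mean - 0.5 * stats.returns_stat.std)

trainer = OffpolicyTrainer(
    policy=policy,                    # assumed to be defined elsewhere
    train_collector=train_collector,  # assumed to be defined elsewhere
    test_collector=test_collector,    # assumed to be defined elsewhere
    max_epoch=10,
    step_per_epoch=10_000,
    step_per_collect=10,
    episode_per_test=10,
    batch_size=64,
    update_per_step=0.1,
    evaluate_test_fn=risk_adjusted_score,  # the hook added by this commit
)
info = trainer.run()
# best_score follows risk_adjusted_score; best_reward still reports the mean
# test reward observed at the best-scoring epoch.
print(info.best_score, info.best_reward)
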
2 changes: 2 additions & 0 deletions tianshou/trainer/utils.py
@@ -42,6 +42,7 @@ def gather_info(
start_time: float,
policy_update_time: float,
gradient_step: int,
best_score: float,
best_reward: float,
best_reward_std: float,
train_collector: BaseCollector | None = None,
@@ -75,6 +76,7 @@

return InfoStats(
gradient_step=gradient_step,
best_score=best_score,
best_reward=best_reward,
best_reward_std=best_reward_std,
train_step=train_collector.collect_step if train_collector is not None else 0,
