Support computing custom scores and terminating/saving based on them in BaseTrainer #1202

Merged 4 commits on Aug 14, 2024
2 changes: 2 additions & 0 deletions tianshou/data/stats.py
@@ -56,6 +56,8 @@ class InfoStats(DataclassPPrintMixin):

gradient_step: int
"""The total gradient step."""
best_score: float
"""The best score over the test results. The one with the highest score will be considered the best model."""
best_reward: float
"""The best reward over the test results."""
best_reward_std: float
38 changes: 32 additions & 6 deletions tianshou/trainer/base.py
@@ -73,6 +73,8 @@ class BaseTrainer(ABC):
:param test_fn: a hook called at the beginning of testing in each
epoch. It can be used to perform custom additional operations, with the
signature ``f(num_epoch: int, step_idx: int) -> None``.
:param compute_score_fn: a function that computes the performance score of a test
batch, with the signature ``f(stat: CollectStats) -> float``. The policy with the
highest score is considered the best model. If not provided, the mean return
is used as the score.
:param save_best_fn: a hook called when the undiscounted average mean
reward in evaluation phase gets better, with the signature
``f(policy: BasePolicy) -> None``.
@@ -164,6 +166,7 @@ def __init__(
train_fn: Callable[[int, int], None] | None = None,
test_fn: Callable[[int, int | None], None] | None = None,
stop_fn: Callable[[float], bool] | None = None,
compute_score_fn: Callable[[CollectStats], float] | None = None,
save_best_fn: Callable[[BasePolicy], None] | None = None,
save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
resume_from_log: bool = False,
@@ -185,6 +188,7 @@ def __init__(
self.logger = logger
self.start_time = time.time()
self.stat: defaultdict[str, MovAvg] = defaultdict(MovAvg)
self.best_score = 0.0
self.best_reward = 0.0
self.best_reward_std = 0.0
self.start_epoch = 0
@@ -210,6 +214,14 @@ def __init__(
self.train_fn = train_fn
self.test_fn = test_fn
self.stop_fn = stop_fn
self.compute_score_fn: Callable[[CollectStats], float]
if compute_score_fn is None:

def compute_score_fn(stat: CollectStats) -> float:
assert stat.returns_stat is not None # for mypy
return stat.returns_stat.mean

self.compute_score_fn = compute_score_fn
self.save_best_fn = save_best_fn
self.save_checkpoint_fn = save_checkpoint_fn

@@ -273,6 +285,7 @@ def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> None:
test_result.returns_stat.mean,
test_result.returns_stat.std,
)
self.best_score = self.compute_score_fn(test_result)
if self.save_best_fn:
self.save_best_fn(self.policy)

@@ -351,6 +364,7 @@ def __next__(self) -> EpochStats:
start_time=self.start_time,
policy_update_time=self.policy_update_time,
gradient_step=self._gradient_step,
best_score=self.best_score,
best_reward=self.best_reward,
best_reward_std=self.best_reward_std,
train_collector=self.train_collector,
@@ -384,17 +398,27 @@ def test_step(self) -> tuple[CollectStats, bool]:
)
assert test_stat.returns_stat is not None # for mypy
rew, rew_std = test_stat.returns_stat.mean, test_stat.returns_stat.std
if self.best_epoch < 0 or self.best_reward < rew:
score = self.compute_score_fn(test_stat)
if self.best_epoch < 0 or self.best_score < score:
self.best_score = score
self.best_epoch = self.epoch
self.best_reward = float(rew)
self.best_reward_std = rew_std
if self.save_best_fn:
self.save_best_fn(self.policy)
log_msg = (
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
f" best_reward: {self.best_reward:.6f} ± "
f"{self.best_reward_std:.6f} in #{self.best_epoch}"
)
if score != rew:
# a custom score function is in use, so include the score in the log message
log_msg = (
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, score: {score:.6f},"
f" best_reward: {self.best_reward:.6f} ± "
f"{self.best_reward_std:.6f}, score: {self.best_score:.6f} in #{self.best_epoch}"
)
else:
log_msg = (
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
f" best_reward: {self.best_reward:.6f} ± "
f"{self.best_reward_std:.6f} in #{self.best_epoch}"
)
log.info(log_msg)
if self.verbose:
print(log_msg, flush=True)
@@ -506,6 +530,7 @@ def _update_best_reward_and_return_should_stop_training(
should_stop_training = True
self.best_reward = test_result.returns_stat.mean
self.best_reward_std = test_result.returns_stat.std
self.best_score = self.compute_score_fn(test_result)

return should_stop_training

@@ -562,6 +587,7 @@ def run(self, reset_prior_to_run: bool = True) -> InfoStats:
start_time=self.start_time,
policy_update_time=self.policy_update_time,
gradient_step=self._gradient_step,
best_score=self.best_score,
best_reward=self.best_reward,
best_reward_std=self.best_reward_std,
train_collector=self.train_collector,
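For context, a minimal usage sketch of the new hook (not part of the diff): the OffpolicyTrainer construction and every argument except compute_score_fn are placeholder assumptions here; only the compute_score_fn keyword and its Callable[[CollectStats], float] signature come from this PR. The example trades mean return against its spread when ranking models.

from tianshou.data import CollectStats
from tianshou.trainer import OffpolicyTrainer

def risk_adjusted_score(stat: CollectStats) -> float:
    # Favor policies whose test returns are both high on average and consistent.
    assert stat.returns_stat is not None
    return stat.returns_stat.mean - stat.returns_stat.std

# policy, train_collector and test_collector are assumed to be set up elsewhere.
trainer = OffpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=10,
    step_per_epoch=1000,
    step_per_collect=10,
    episode_per_test=10,
    batch_size=64,
    compute_score_fn=risk_adjusted_score,
)

With this in place, test_step ranks models by the risk-adjusted value rather than the plain mean return, and save_best_fn fires whenever that score improves.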
2 changes: 2 additions & 0 deletions tianshou/trainer/utils.py
@@ -42,6 +42,7 @@ def gather_info(
start_time: float,
policy_update_time: float,
gradient_step: int,
best_score: float,
best_reward: float,
best_reward_std: float,
train_collector: BaseCollector | None = None,
@@ -75,6 +76,7 @@ def gather_info(

return InfoStats(
gradient_step=gradient_step,
best_score=best_score,
best_reward=best_reward,
best_reward_std=best_reward_std,
train_step=train_collector.collect_step if train_collector is not None else 0,
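A brief follow-up to the sketch above (again an illustration, not part of the diff): because run() returns an InfoStats and this PR threads best_score through gather_info into it, the score that selected the best model can be read back after training.

info_stats = trainer.run()
print(
    f"best score: {info_stats.best_score:.6f}, "
    f"best reward: {info_stats.best_reward:.6f} ± {info_stats.best_reward_std:.6f}"
)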