Support computing custom scores and terminating/saving based on them in BaseTrainer #1202

Merged · 4 commits · Aug 14, 2024
2 changes: 2 additions & 0 deletions tianshou/data/stats.py
@@ -56,6 +56,8 @@ class InfoStats(DataclassPPrintMixin):

     gradient_step: int
     """The total gradient step."""
+    best_score: float
"""The best score over the test results."""
best_reward: float
"""The best reward over the test results."""
best_reward_std: float
38 changes: 32 additions & 6 deletions tianshou/trainer/base.py
@@ -73,6 +73,7 @@ class BaseTrainer(ABC):
     :param test_fn: a hook called at the beginning of testing in each
         epoch. It can be used to perform custom additional operations, with the
         signature ``f(num_epoch: int, step_idx: int) -> None``.
+    :param evaluate_test_fn: a hook for computing a custom performance score from the
+        test collection statistics, used to decide whether the current model is the
+        best one so far, with the signature ``f(collect_stats: CollectStats) -> float``.
+        If not provided, the mean test reward is used as the score.
     :param save_best_fn: a hook called when the undiscounted average mean
         reward in evaluation phase gets better, with the signature
         ``f(policy: BasePolicy) -> None``.
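
For context, a minimal usage sketch of the new hook (not taken from this PR): a hypothetical scorer that rewards high mean test return with low variance. It only reads CollectStats.returns_stat, the same field the trainer uses below; the import path and the OffpolicyTrainer call in the comment are assumptions about the surrounding setup.

# Hypothetical usage sketch: score test runs by mean return minus its standard
# deviation, so the "best" model trades a little raw reward for stability.
from tianshou.data import CollectStats


def mean_minus_std(stats: CollectStats) -> float:
    """Custom score: mean test return penalized by the return std."""
    assert stats.returns_stat is not None  # populated after test collection
    return float(stats.returns_stat.mean - stats.returns_stat.std)


# Passed to any BaseTrainer subclass, e.g. (other arguments omitted):
# trainer = OffpolicyTrainer(policy=policy, ..., evaluate_test_fn=mean_minus_std)
# Best-model tracking, save_best_fn, and InfoStats.best_score then follow this score.
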
@@ -164,6 +165,7 @@ def __init__(
         train_fn: Callable[[int, int], None] | None = None,
         test_fn: Callable[[int, int | None], None] | None = None,
         stop_fn: Callable[[float], bool] | None = None,
+        evaluate_test_fn: Callable[[CollectStats], float] | None = None,
         save_best_fn: Callable[[BasePolicy], None] | None = None,
         save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
         resume_from_log: bool = False,
@@ -185,6 +187,7 @@ def __init__(
         self.logger = logger
         self.start_time = time.time()
         self.stat: defaultdict[str, MovAvg] = defaultdict(MovAvg)
+        self.best_score = 0.0
         self.best_reward = 0.0
         self.best_reward_std = 0.0
         self.start_epoch = 0
@@ -210,6 +213,7 @@ def __init__(
         self.train_fn = train_fn
         self.test_fn = test_fn
         self.stop_fn = stop_fn
+        self.evaluate_test_fn = evaluate_test_fn
         self.save_best_fn = save_best_fn
         self.save_checkpoint_fn = save_checkpoint_fn

@@ -273,6 +277,10 @@ def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> None:
                 test_result.returns_stat.mean,
                 test_result.returns_stat.std,
             )
+            if self.evaluate_test_fn:
+                self.best_score = self.evaluate_test_fn(test_result)
+            else:
+                self.best_score = self.best_reward
         if self.save_best_fn:
             self.save_best_fn(self.policy)

@@ -351,6 +359,7 @@ def __next__(self) -> EpochStats:
             start_time=self.start_time,
             policy_update_time=self.policy_update_time,
             gradient_step=self._gradient_step,
+            best_score=self.best_score,
             best_reward=self.best_reward,
             best_reward_std=self.best_reward_std,
             train_collector=self.train_collector,
@@ -384,17 +393,29 @@ def test_step(self) -> tuple[CollectStats, bool]:
         )
         assert test_stat.returns_stat is not None  # for mypy
         rew, rew_std = test_stat.returns_stat.mean, test_stat.returns_stat.std
-        if self.best_epoch < 0 or self.best_reward < rew:
+        if self.evaluate_test_fn:
+            score = self.evaluate_test_fn(test_stat)
+        else:
+            score = float(rew)
+        if self.best_epoch < 0 or self.best_score < score:
+            self.best_score = score
             self.best_epoch = self.epoch
             self.best_reward = float(rew)
             self.best_reward_std = rew_std
             if self.save_best_fn:
                 self.save_best_fn(self.policy)
-        log_msg = (
-            f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
-            f" best_reward: {self.best_reward:.6f} ± "
-            f"{self.best_reward_std:.6f} in #{self.best_epoch}"
-        )
+        if self.evaluate_test_fn:
+            log_msg = (
+                f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, score: {score:.6f},"
+                f" best_reward: {self.best_reward:.6f} ± "
+                f"{self.best_reward_std:.6f}, score: {self.best_score:.6f} in #{self.best_epoch}"
+            )
+        else:
+            log_msg = (
+                f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
+                f" best_reward: {self.best_reward:.6f} ± "
+                f"{self.best_reward_std:.6f} in #{self.best_epoch}"
+            )
         log.info(log_msg)
         if self.verbose:
             print(log_msg, flush=True)
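
Because save_best_fn now fires whenever the tracked score improves, an existing reward-based checkpoint hook keeps working unchanged. Below is a hypothetical sketch of such a hook; log_path and the file name are assumptions for illustration, not part of this PR.

# Hypothetical checkpoint hook: called by the trainer each time the best score
# improves, whether the score comes from evaluate_test_fn or the mean reward.
import os

import torch

from tianshou.policy import BasePolicy

log_path = "log/example"  # assumed output directory


def save_best_fn(policy: BasePolicy) -> None:
    """Save the policy weights each time the trainer reports a new best score."""
    os.makedirs(log_path, exist_ok=True)
    torch.save(policy.state_dict(), os.path.join(log_path, "best_policy.pth"))
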
@@ -506,6 +527,10 @@ def _update_best_reward_and_return_should_stop_training(
                 should_stop_training = True
                 self.best_reward = test_result.returns_stat.mean
                 self.best_reward_std = test_result.returns_stat.std
+                if self.evaluate_test_fn:
+                    self.best_score = self.evaluate_test_fn(test_result)
+                else:
+                    self.best_score = self.best_reward

         return should_stop_training

@@ -562,6 +587,7 @@ def run(self, reset_prior_to_run: bool = True) -> InfoStats:
                 start_time=self.start_time,
                 policy_update_time=self.policy_update_time,
                 gradient_step=self._gradient_step,
+                best_score=self.best_score,
                 best_reward=self.best_reward,
                 best_reward_std=self.best_reward_std,
                 train_collector=self.train_collector,
2 changes: 2 additions & 0 deletions tianshou/trainer/utils.py
@@ -42,6 +42,7 @@ def gather_info(
     start_time: float,
     policy_update_time: float,
     gradient_step: int,
+    best_score: float,
     best_reward: float,
     best_reward_std: float,
     train_collector: BaseCollector | None = None,
@@ -75,6 +76,7 @@

     return InfoStats(
         gradient_step=gradient_step,
+        best_score=best_score,
         best_reward=best_reward,
         best_reward_std=best_reward_std,
         train_step=train_collector.collect_step if train_collector is not None else 0,
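
With this change, the InfoStats returned by run() exposes best_score alongside the reward statistics. A small hedged sketch, assuming a trainer constructed with the scorer from the earlier example:

# Hypothetical sketch: inspect the new field on the run result, assuming `trainer`
# was constructed with evaluate_test_fn=mean_minus_std as shown above.
info_stats = trainer.run()
print(f"best score:  {info_stats.best_score:.3f}")
print(f"best reward: {info_stats.best_reward:.3f} ± {info_stats.best_reward_std:.3f}")
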