Skip to content

Commit

Permalink
Update keras.py black formatter
Browse files Browse the repository at this point in the history
Formatted using black...
  • Loading branch information
Furkan-rgb authored Nov 3, 2024
1 parent 82ee71b commit 1fe9572
Showing 1 changed file with 3 additions and 28 deletions.
31 changes: 3 additions & 28 deletions optuna_integration/keras/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
import warnings

import optuna

from optuna_integration._imports import try_import


with try_import() as _imports:
from keras.callbacks import Callback

Expand All @@ -29,58 +27,35 @@ class KerasPruningCallback(Callback):
An evaluation metric for pruning, e.g., ``val_loss`` and
``val_accuracy``. Please refer to `keras.Callback reference
<https://keras.io/callbacks/#callback>`_ for further details.
fold:
Current fold number in k-fold cross validation. Used to ensure steps continue
from the last epoch of the previous fold.
interval:
Check if trial should be pruned every n-th epoch. By default ``interval=1`` and
pruning is performed after every epoch. Increase ``interval`` to run several
epochs faster before applying pruning.
"""

# Class variable to track the last step across folds
_last_step = -1 # Start at -1 so first epoch of first fold starts at 0

def __init__(
self,
trial: optuna.trial.Trial,
monitor: str,
fold: int = 0,
interval: int = 1
) -> None:
def __init__(self, trial: optuna.trial.Trial, monitor: str, interval: int = 1) -> None:
super().__init__()

_imports.check()

self._trial = trial
self._monitor = monitor
self._interval = interval
self._fold = fold
# For this fold, start steps after the last step used
self._step_offset = self.__class__._last_step + 1

def on_epoch_end(self, epoch: int, logs: dict[str, float] | None = None) -> None:
if (epoch + 1) % self._interval != 0:
return

logs = logs or {}
current_score = logs.get(self._monitor)

if current_score is None:
message = (
"The metric '{}' is not in the evaluation logs for pruning. "
"Please make sure you set the correct metric name.".format(self._monitor)
)
warnings.warn(message)
return

# Calculate current step by adding epoch to the offset
current_step = self._step_offset + epoch
# Update the class's last step tracker
self.__class__._last_step = current_step

self._trial.report(float(current_score), step=current_step)

self._trial.report(float(current_score), step=epoch)
if self._trial.should_prune():
message = "Trial was pruned at fold {}, epoch {}.".format(self._fold, epoch)
message = "Trial was pruned at epoch {}.".format(epoch)
raise optuna.TrialPruned(message)

0 comments on commit 1fe9572

Please sign in to comment.