From 1fe95721d23755e8e0c049f14d908334b4c4c37a Mon Sep 17 00:00:00 2001 From: Furkan-rgb <50831308+Furkan-rgb@users.noreply.github.com> Date: Sun, 3 Nov 2024 13:16:31 +0100 Subject: [PATCH] Update keras.py black formatter Formatted using black... --- optuna_integration/keras/keras.py | 31 +++---------------------------- 1 file changed, 3 insertions(+), 28 deletions(-) diff --git a/optuna_integration/keras/keras.py b/optuna_integration/keras/keras.py index 6cb7a2e2..3cf7184f 100644 --- a/optuna_integration/keras/keras.py +++ b/optuna_integration/keras/keras.py @@ -3,10 +3,8 @@ import warnings import optuna - from optuna_integration._imports import try_import - with try_import() as _imports: from keras.callbacks import Callback @@ -29,25 +27,13 @@ class KerasPruningCallback(Callback): An evaluation metric for pruning, e.g., ``val_loss`` and ``val_accuracy``. Please refer to `keras.Callback reference `_ for further details. - fold: - Current fold number in k-fold cross validation. Used to ensure steps continue - from the last epoch of the previous fold. interval: Check if trial should be pruned every n-th epoch. By default ``interval=1`` and pruning is performed after every epoch. Increase ``interval`` to run several epochs faster before applying pruning. """ - # Class variable to track the last step across folds - _last_step = -1 # Start at -1 so first epoch of first fold starts at 0 - - def __init__( - self, - trial: optuna.trial.Trial, - monitor: str, - fold: int = 0, - interval: int = 1 - ) -> None: + def __init__(self, trial: optuna.trial.Trial, monitor: str, interval: int = 1) -> None: super().__init__() _imports.check() @@ -55,9 +41,6 @@ def __init__( self._trial = trial self._monitor = monitor self._interval = interval - self._fold = fold - # For this fold, start steps after the last step used - self._step_offset = self.__class__._last_step + 1 def on_epoch_end(self, epoch: int, logs: dict[str, float] | None = None) -> None: if (epoch + 1) % self._interval != 0: @@ -65,7 +48,6 @@ def on_epoch_end(self, epoch: int, logs: dict[str, float] | None = None) -> None logs = logs or {} current_score = logs.get(self._monitor) - if current_score is None: message = ( "The metric '{}' is not in the evaluation logs for pruning. " @@ -73,14 +55,7 @@ def on_epoch_end(self, epoch: int, logs: dict[str, float] | None = None) -> None ) warnings.warn(message) return - - # Calculate current step by adding epoch to the offset - current_step = self._step_offset + epoch - # Update the class's last step tracker - self.__class__._last_step = current_step - - self._trial.report(float(current_score), step=current_step) - + self._trial.report(float(current_score), step=epoch) if self._trial.should_prune(): - message = "Trial was pruned at fold {}, epoch {}.".format(self._fold, epoch) + message = "Trial was pruned at epoch {}.".format(epoch) raise optuna.TrialPruned(message)