Skip to content

Commit

Permalink
Fix tracking of different runs when using WandbLogger (#766)
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaenzig authored Feb 27, 2025
1 parent 2fea425 commit d2b2698
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 22 deletions.
33 changes: 33 additions & 0 deletions src/eva/core/loggers/utils/wandb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# type: ignore
"""Utility functions for logging with Weights & Biases."""

from typing import Any, Dict

from loguru import logger


def rename_active_run(name: str) -> None:
    """Renames the current run."""
    import wandb

    active_run = wandb.run
    if not active_run:
        # Nothing to rename; emit a warning instead of failing silently.
        logger.warning("No active wandb run found that could be renamed.")
        return

    active_run.name = name
    active_run.save()


def init_run(name: str, init_kwargs: Dict[str, Any]) -> None:
    """Initializes a new run. If there is an active run, it will be renamed and reused.

    Args:
        name: The name to assign to the (new or reused) run.
        init_kwargs: Keyword arguments forwarded to ``wandb.init``. The
            mapping is not modified in place; the run name is applied to
            a shallow copy so repeated calls with a shared config dict
            don't leak the previous run's name into the caller's state.
    """
    import wandb

    # Fix: avoid mutating the caller's dict (previously `init_kwargs["name"] = name`).
    kwargs = {**init_kwargs, "name": name}
    rename_active_run(name)
    wandb.init(**kwargs)


def finish_run() -> None:
    """Finish the current run."""
    # Imported lazily so wandb is only required when this utility is used.
    import wandb

    wandb.finish()
13 changes: 8 additions & 5 deletions src/eva/core/trainers/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def run_evaluation_session(
base_trainer,
base_model,
datamodule,
run_id=f"run_{run_index}",
run_id=run_index,
verbose=not verbose,
)
recorder.update(validation_scores, test_scores)
Expand All @@ -51,7 +51,7 @@ def run_evaluation(
base_model: modules.ModelModule,
datamodule: datamodules.DataModule,
*,
run_id: str | None = None,
run_id: int | None = None,
verbose: bool = True,
) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
"""Fits and evaluates a model out-of-place.
Expand All @@ -61,7 +61,6 @@ def run_evaluation(
base_model: The model module to use but not modify.
datamodule: The data module.
run_id: The run id to be appended to the output log directory.
If `None`, it will use the log directory of the trainer as is.
verbose: Whether to print the validation and test metrics
in the end of the training.
Expand All @@ -70,8 +69,12 @@ def run_evaluation(
"""
trainer, model = _utils.clone(base_trainer, base_model)
model.configure_model()
trainer.setup_log_dirs(run_id or "")
return fit_and_validate(trainer, model, datamodule, verbose=verbose)

trainer.init_logger_run(run_id)
results = fit_and_validate(trainer, model, datamodule, verbose=verbose)
trainer.finish_logger_run(run_id)

return results


def fit_and_validate(
Expand Down
49 changes: 32 additions & 17 deletions src/eva/core/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from eva.core import loggers as eva_loggers
from eva.core.data import datamodules
from eva.core.loggers.utils import wandb as wandb_utils
from eva.core.models import modules
from eva.core.trainers import _logging, functional

Expand Down Expand Up @@ -53,7 +54,7 @@ def __init__(
self._session_id: str = _logging.generate_session_id()
self._log_dir: str = self.default_log_dir

self.setup_log_dirs()
self.init_logger_run(0)

@property
def default_log_dir(self) -> str:
Expand All @@ -65,31 +66,45 @@ def default_log_dir(self) -> str:
def log_dir(self) -> str | None:
return self.strategy.broadcast(self._log_dir)

def setup_log_dirs(self, subdirectory: str = "") -> None:
"""Setups the logging directory of the trainer and experimental loggers in-place.
def init_logger_run(self, run_id: int | None) -> None:
"""Setup the loggers & log directories when starting a new run.
Args:
subdirectory: Whether to append a subdirectory to the output log.
run_id: The id of the current run.
"""
subdirectory = f"run_{run_id}" if run_id is not None else ""
self._log_dir = os.path.join(self.default_root_dir, self._session_id, subdirectory)

enabled_loggers = []
if isinstance(self.loggers, list) and len(self.loggers) > 0:
for logger in self.loggers:
if isinstance(logger, (pl_loggers.CSVLogger, pl_loggers.TensorBoardLogger)):
if not cloud_io._is_local_file_protocol(self.default_root_dir):
loguru.logger.warning(
f"Skipped {type(logger).__name__} as remote storage is not supported."
)
continue
else:
logger._root_dir = self.default_root_dir
logger._name = self._session_id
logger._version = subdirectory
enabled_loggers.append(logger)
for logger in self.loggers or []:
if isinstance(logger, (pl_loggers.CSVLogger, pl_loggers.TensorBoardLogger)):
if not cloud_io._is_local_file_protocol(self.default_root_dir):
loguru.logger.warning(
f"Skipped {type(logger).__name__} as remote storage is not supported."
)
continue
else:
logger._root_dir = self.default_root_dir
logger._name = self._session_id
logger._version = subdirectory
elif isinstance(logger, pl_loggers.WandbLogger):
task_name = self.default_root_dir.split("/")[-1]
run_name = os.getenv("WANDB_RUN_NAME", f"{task_name}_{self._session_id}")
wandb_utils.init_run(f"{run_name}_{run_id}", logger._wandb_init)
enabled_loggers.append(logger)

self._loggers = enabled_loggers or [eva_loggers.DummyLogger(self._log_dir)]

def finish_logger_run(self, run_id: int | None) -> None:
    """Finish the current run in the enabled loggers.

    Args:
        run_id: The id of the current run.
    """
    # NOTE(review): `run_id` is unused here; presumably kept for a symmetric
    # interface with `init_logger_run` — confirm before removing.
    wandb_loggers = [
        logger
        for logger in (self.loggers or [])
        if isinstance(logger, pl_loggers.WandbLogger)
    ]
    for _ in wandb_loggers:
        wandb_utils.finish_run()

def run_evaluation_session(
self,
model: modules.ModelModule,
Expand Down

0 comments on commit d2b2698

Please sign in to comment.