From 18ad17fb50d5d50b1db4da25daefc140c1b2274d Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Tue, 5 Nov 2024 21:25:17 +0000 Subject: [PATCH] Add setup_trainer --- crabs/tracker/track_video.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/crabs/tracker/track_video.py b/crabs/tracker/track_video.py index fa66e8ee..41497041 100644 --- a/crabs/tracker/track_video.py +++ b/crabs/tracker/track_video.py @@ -7,6 +7,7 @@ from pathlib import Path import cv2 +import lightning import numpy as np import torch import torchvision.transforms.v2 as transforms @@ -16,6 +17,7 @@ from crabs.detector.utils.detection import ( log_mlflow_metadata_as_info, set_mlflow_run_name, + setup_mlflow_logger, ) from crabs.tracker.evaluate_tracker import TrackerEvaluate from crabs.tracker.sort import Sort @@ -73,6 +75,33 @@ def __init__(self, args: argparse.Namespace) -> None: # Log MLflow information to screen log_mlflow_metadata_as_info(self) + def setup_trainer(self): + """Set up trainer object with logging for testing.""" + # Setup logger + mlf_logger = setup_mlflow_logger( + experiment_name=self.experiment_name, + run_name=self.run_name, + mlflow_folder=self.mlflow_folder, + cli_args=self.args, + ) + + # Add trained model section to MLflow hyperparameters + mlf_logger.log_hyperparams( + { + "trained_model/experiment_name": self.trained_model_expt_name, + "trained_model/run_name": self.trained_model_run_name, + "trained_model/ckpt_file": Path(self.trained_model_path).name, + } + ) + + # Add other unlogged information from init? + + # Return trainer linked to logger + return lightning.Trainer( + accelerator=self.accelerator, # lightning accelerators + logger=mlf_logger, + ) + def setup(self): """Load tracking config, trained model and input video path.""" with open(self.config_file) as f: @@ -334,10 +363,11 @@ def tracking_parse_args(args): parser.add_argument( "--experiment_name", type=str, + default="Inference", help=( "Name of the experiment in MLflow, under which the current run " "will be logged. " - "By default: _evaluation." + "By default: Inference." ), ) parser.add_argument(